#3062 - Remove discriminators from abstract classes and fix remaining old discriminator names

This commit is contained in:
Marek Wolan
2025-02-05 10:12:13 +00:00
parent 0a6b604afd
commit 4a472c5c75
28 changed files with 55 additions and 79 deletions

View File

@@ -12,8 +12,6 @@ from primaite.interface.request import RequestFormat
class AbstractAction(BaseModel, ABC):
"""Base class for actions."""
config: "AbstractAction.ConfigSchema"
class ConfigSchema(BaseModel, ABC):
"""Base configuration schema for Actions."""

View File

@@ -21,9 +21,7 @@ __all__ = (
class ACLAddRuleAbstractAction(AbstractAction, ABC):
"""Base abstract class for ACL add rule actions."""
config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Configuration Schema base for ACL add rule abstract actions."""
src_ip: Union[IPV4Address, Literal["ALL"]]
@@ -37,12 +35,10 @@ class ACLAddRuleAbstractAction(AbstractAction, ABC):
dst_wildcard: Union[IPV4Address, Literal["NONE"]]
class ACLRemoveRuleAbstractAction(AbstractAction, discriminator="acl-remove-rule-abstract-action"):
class ACLRemoveRuleAbstractAction(AbstractAction, ABC):
"""Base abstract class for acl remove rule actions."""
config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Configuration Schema base for ACL remove rule abstract actions."""
position: int

View File

@@ -23,9 +23,7 @@ class NodeApplicationAbstractAction(AbstractAction, ABC):
inherit from this base class.
"""
config: "NodeApplicationAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Base Configuration schema for Node Application actions."""
node_name: str

View File

@@ -24,9 +24,7 @@ class NodeFileAbstractAction(AbstractAction, ABC):
only three parameters can inherit from this base class.
"""
config: "NodeFileAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Configuration Schema for NodeFileAbstractAction."""
node_name: str

View File

@@ -22,9 +22,7 @@ class NodeFolderAbstractAction(AbstractAction, ABC):
this base class.
"""
config: "NodeFolderAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Base configuration schema for NodeFolder actions."""
node_name: str

View File

@@ -16,9 +16,7 @@ class HostNICAbstractAction(AbstractAction, ABC):
base class.
"""
config: "HostNICAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Base Configuration schema for HostNIC actions."""
node_name: str

View File

@@ -1,5 +1,6 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
@@ -8,12 +9,10 @@ from primaite.interface.request import RequestFormat
__all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction")
class NetworkPortAbstractAction(AbstractAction, discriminator="network-port-abstract"):
class NetworkPortAbstractAction(AbstractAction, ABC):
"""Base class for Network port actions."""
config: "NetworkPortAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Base configuration schema for NetworkPort actions."""
target_nodename: str

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import ClassVar, List, Optional, Union
from primaite.game.agent.actions.manager import AbstractAction
@@ -18,16 +18,14 @@ __all__ = (
)
class NodeAbstractAction(AbstractAction, discriminator="node-abstract"):
class NodeAbstractAction(AbstractAction, ABC):
"""
Abstract base class for node actions.
Any action which applies to a node and uses node_name as its only parameter can inherit from this base class.
"""
config: "NodeAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Base Configuration schema for Node actions."""
node_name: str
@@ -83,12 +81,10 @@ class NodeResetAction(NodeAbstractAction, discriminator="node-reset"):
verb: ClassVar[str] = "reset"
class NodeNMAPAbstractAction(AbstractAction, discriminator="node-nmap-abstract-action"):
class NodeNMAPAbstractAction(AbstractAction, ABC):
"""Base class for NodeNMAP actions."""
config: "NodeNMAPAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Base Configuration Schema for NodeNMAP actions."""
target_ip_address: Union[str, List[str]]

View File

@@ -1,4 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import ABC
from typing import ClassVar
from primaite.game.agent.actions.manager import AbstractAction
@@ -17,15 +18,13 @@ __all__ = (
)
class NodeServiceAbstractAction(AbstractAction, discriminator="node-service-abstract"):
class NodeServiceAbstractAction(AbstractAction, ABC):
"""Abstract Action for Node Service related actions.
Any actions which use node_name and service_name can inherit from this class.
"""
config: "NodeServiceAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
node_name: str
service_name: str
verb: ClassVar[str]

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
from abc import abstractmethod
from abc import ABC, abstractmethod
from primaite.game.agent.actions.manager import AbstractAction
from primaite.interface.request import RequestFormat
@@ -11,12 +11,10 @@ __all__ = (
)
class NodeSessionAbstractAction(AbstractAction, discriminator="node-session-abstract"):
class NodeSessionAbstractAction(AbstractAction, ABC):
"""Base class for NodeSession actions."""
config: "NodeSessionAbstractAction.ConfigSchema"
class ConfigSchema(AbstractAction.ConfigSchema):
class ConfigSchema(AbstractAction.ConfigSchema, ABC):
"""Base configuration schema for NodeSessionAbstractActions."""
node_name: str

View File

@@ -68,7 +68,7 @@ class AbstractAgent(BaseModel, ABC):
)
reward_function: RewardFunction.ConfigSchema = Field(default_factory=lambda: RewardFunction.ConfigSchema())
config: "AbstractAgent.ConfigSchema" = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
config: ConfigSchema = Field(default_factory=lambda: AbstractAgent.ConfigSchema())
logger: AgentLog = AgentLog(agent_name="Abstract_Agent")
history: List[AgentHistoryItem] = []
@@ -161,16 +161,16 @@ class AbstractAgent(BaseModel, ABC):
return agent_class(config=config)
class AbstractScriptedAgent(AbstractAgent, discriminator="abstract-scripted-agent"):
class AbstractScriptedAgent(AbstractAgent, ABC):
"""Base class for actors which generate their own behaviour."""
config: "AbstractScriptedAgent.ConfigSchema" = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
class ConfigSchema(AbstractAgent.ConfigSchema):
class ConfigSchema(AbstractAgent.ConfigSchema, ABC):
"""Configuration Schema for AbstractScriptedAgents."""
type: str = "AbstractScriptedAgent"
config: ConfigSchema = Field(default_factory=lambda: AbstractScriptedAgent.ConfigSchema())
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Return an action to be taken in the environment."""

View File

@@ -43,16 +43,16 @@ _LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
class AbstractReward(BaseModel):
class AbstractReward(BaseModel, ABC):
"""Base class for reward function components."""
config: "AbstractReward.ConfigSchema"
class ConfigSchema(BaseModel, ABC):
"""Config schema for AbstractReward."""
type: str = ""
config: ConfigSchema
_registry: ClassVar[Dict[str, Type["AbstractReward"]]] = {}
def __init_subclass__(cls, discriminator: Optional[str] = None, **kwargs: Any) -> None:

View File

@@ -2,7 +2,7 @@
from __future__ import annotations
import random
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from gymnasium.core import ObsType
@@ -13,18 +13,18 @@ from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
__all__ = "AbstractTAPAgent"
class AbstractTAPAgent(PeriodicAgent, discriminator="abstract-tap"):
class AbstractTAPAgent(PeriodicAgent, ABC):
"""Base class for TAP agents to inherit from."""
config: "AbstractTAPAgent.ConfigSchema" = Field(default_factory=lambda: AbstractTAPAgent.ConfigSchema())
next_execution_timestep: int = 0
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema):
class AgentSettingsSchema(PeriodicAgent.AgentSettingsSchema, ABC):
"""Schema for the `agent_settings` part of the agent config."""
possible_starting_nodes: List[str] = Field(default_factory=list)
class ConfigSchema(PeriodicAgent.ConfigSchema):
class ConfigSchema(PeriodicAgent.ConfigSchema, ABC):
"""Configuration schema for Abstract TAP agents."""
type: str = "abstract-tap"

View File

@@ -6,8 +6,6 @@ from pydantic import Field
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
__all__ = "DataManipulationAgent"
class DataManipulationAgent(PeriodicAgent, discriminator="red-database-corrupting-agent"):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""

View File

@@ -288,8 +288,8 @@ class PrimaiteGame:
if "folder_restore_duration" in defaults_config:
new_node.file_system._default_folder_restore_duration = defaults_config["folder_restore_duration"]
if "users" in node_cfg and new_node.software_manager.software.get("UserManager"):
user_manager: UserManager = new_node.software_manager.software["UserManager"] # noqa
if "users" in node_cfg and new_node.software_manager.software.get("user-manager"):
user_manager: UserManager = new_node.software_manager.software["user-manager"] # noqa
for user_cfg in node_cfg["users"]:
user_manager.add_user(**user_cfg, bypass_can_perform_action=True)

View File

@@ -170,13 +170,13 @@ def arcd_uc2_network() -> Network:
client_1.power_on()
network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1])
client_1.software_manager.install(DatabaseClient)
db_client_1: DatabaseClient = client_1.software_manager.software.get("DatabaseClient")
db_client_1: DatabaseClient = client_1.software_manager.software.get("database-client")
db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14"))
db_client_1.run()
web_browser_1 = client_1.software_manager.software.get("web-browser")
web_browser_1.target_url = "http://arcd.com/users/"
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("data-manipulation-bot")
db_manipulation_bot.configure(
server_ip_address=IPv4Address("192.168.1.14"),
payload="DELETE",

View File

@@ -37,7 +37,7 @@ class DatabaseClientConnection(BaseModel):
@property
def client(self) -> Optional[DatabaseClient]:
"""The DatabaseClient that holds this connection."""
return self.parent_node.software_manager.software.get("DatabaseClient")
return self.parent_node.software_manager.software.get("database-client")
def query(self, sql: str) -> bool:
"""

View File

@@ -70,7 +70,7 @@ class NMAP(Application, discriminator="nmap"):
}
def __init__(self, **kwargs):
kwargs["name"] = "NMAP"
kwargs["name"] = "nmap"
kwargs["port"] = PORT_LOOKUP["NONE"]
kwargs["protocol"] = PROTOCOL_LOOKUP["NONE"]
super().__init__(**kwargs)

View File

@@ -678,8 +678,8 @@ def test_firewall_acl_add_remove_rule_integration():
assert firewall.external_outbound_acl.acl[1].action.name == "DENY"
assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10")
assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10")
assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"]
assert firewall.external_outbound_acl.acl[1].dst_port is None
assert firewall.external_outbound_acl.acl[1].src_port is None
assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"]
env.step(12) # Remove ACL rule from External Outbound

View File

@@ -27,7 +27,7 @@ class BroadcastTestService(Service, discriminator="broadcast-test-service"):
def __init__(self, **kwargs):
# Set default service properties for broadcasting
kwargs["name"] = "BroadcastService"
kwargs["name"] = "broadcast-test-service"
kwargs["port"] = PORT_LOOKUP["HTTP"]
kwargs["protocol"] = PROTOCOL_LOOKUP["TCP"]
super().__init__(**kwargs)
@@ -127,7 +127,7 @@ def broadcast_network() -> Network:
server_1.power_on()
server_1.software_manager.install(BroadcastTestService)
service: BroadcastTestService = server_1.software_manager.software["BroadcastService"]
service: BroadcastTestService = server_1.software_manager.software["broadcast-test-service"]
service.start()
switch_1: Switch = Switch.from_config(
@@ -153,7 +153,7 @@ def broadcast_service_and_clients(
"broadcast-test-client"
]
service: BroadcastTestService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
"broadcast-service"
"broadcast-test-service"
]
return service, client_1, client_2

View File

@@ -22,7 +22,7 @@ class _DatabaseListener(Service, discriminator="database-listener"):
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
config: "_DatabaseListener.ConfigSchema" = Field(default_factory=lambda: _DatabaseListener.ConfigSchema())
name: str = "DatabaseListener"
name: str = "database-listener"
protocol: str = PROTOCOL_LOOKUP["TCP"]
port: int = PORT_LOOKUP["NONE"]
listen_on_ports: Set[int] = {PORT_LOOKUP["POSTGRES_SERVER"]}
@@ -45,7 +45,7 @@ def test_http_listener(client_server):
server_db.start()
server.software_manager.install(_DatabaseListener)
server_db_listener: _DatabaseListener = server.software_manager.software["DatabaseListener"]
server_db_listener: _DatabaseListener = server.software_manager.software["database-listener"]
server_db_listener.start()
computer.software_manager.install(DatabaseClient)

View File

@@ -46,7 +46,7 @@ def test_create_web_client():
computer.power_on()
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("web-browser")
assert web_browser.name is "web-browser"
assert web_browser.name == "web-browser"
assert web_browser.port is PORT_LOOKUP["HTTP"]
assert web_browser.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -29,7 +29,7 @@ def dns_client() -> Computer:
def test_create_dns_client(dns_client):
assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software.get("dns-client")
assert dns_client_service.name is "dns-client"
assert dns_client_service.name == "dns-client"
assert dns_client_service.port is PORT_LOOKUP["DNS"]
assert dns_client_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -33,7 +33,7 @@ def dns_server() -> Node:
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software.get("dns-server")
assert dns_server_service.name is "dns-server"
assert dns_server_service.name == "dns-server"
assert dns_server_service.port is PORT_LOOKUP["DNS"]
assert dns_server_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -32,7 +32,7 @@ def ftp_client() -> Node:
def test_create_ftp_client(ftp_client):
assert ftp_client is not None
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("ftp-client")
assert ftp_client_service.name is "ftp-client"
assert ftp_client_service.name == "ftp-client"
assert ftp_client_service.port is PORT_LOOKUP["FTP"]
assert ftp_client_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -31,7 +31,7 @@ def ftp_server() -> Node:
def test_create_ftp_server(ftp_server):
assert ftp_server is not None
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("ftp-server")
assert ftp_server_service.name is "ftp-server"
assert ftp_server_service.name == "ftp-server"
assert ftp_server_service.port is PORT_LOOKUP["FTP"]
assert ftp_server_service.protocol is PROTOCOL_LOOKUP["TCP"]

View File

@@ -364,7 +364,7 @@ def test_SSH_across_network():
pc_a = network.get_node_by_hostname("client_1")
router_1 = network.get_node_by_hostname("router_1")
terminal_a: Terminal = pc_a.software_manager.software.get("Terminal")
terminal_a: Terminal = pc_a.software_manager.software.get("terminal")
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=PORT_LOOKUP["SSH"], dst_port=PORT_LOOKUP["SSH"], position=21

View File

@@ -34,9 +34,9 @@ def web_server() -> Server:
def test_create_web_server(web_server):
assert web_server is not None
web_server_service: WebServer = web_server.software_manager.software.get("web-server")
assert web_server_service.name is "web-server"
assert web_server_service.port is PORT_LOOKUP["HTTP"]
assert web_server_service.protocol is PROTOCOL_LOOKUP["TCP"]
assert web_server_service.name == "web-server"
assert web_server_service.port == PORT_LOOKUP["HTTP"]
assert web_server_service.protocol == PROTOCOL_LOOKUP["TCP"]
def test_handling_get_request_not_found_path(web_server):