diff --git a/src/primaite/game/agent/actions/__init__.py b/src/primaite/game/agent/actions/__init__.py index 625725fe..7f054591 100644 --- a/src/primaite/game/agent/actions/__init__.py +++ b/src/primaite/game/agent/actions/__init__.py @@ -14,6 +14,7 @@ from primaite.game.agent.actions import ( service, session, ) +from primaite.game.agent.actions.manager import ActionManager __all__ = ( "abstract", diff --git a/src/primaite/game/agent/actions/abstract.py b/src/primaite/game/agent/actions/abstract.py index 9c13cc73..5c0594fd 100644 --- a/src/primaite/game/agent/actions/abstract.py +++ b/src/primaite/game/agent/actions/abstract.py @@ -11,6 +11,8 @@ from primaite.interface.request import RequestFormat class AbstractAction(BaseModel): """Base class for actions.""" + config: "AbstractAction.ConfigSchema" + class ConfigSchema(BaseModel, ABC): """Base configuration schema for Actions.""" @@ -33,6 +35,8 @@ class AbstractAction(BaseModel): @classmethod def from_config(cls, config: Dict) -> "AbstractAction": """Create an action component from a config dictionary.""" - for attribute, value in config.items(): - setattr(cls.ConfigSchema, attribute, value) - return cls + if not config.get("type"): + config.update({"type": cls.__name__}) + print("oooh") + print(config) + return cls(config=cls.ConfigSchema(**config)) diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index 3beface9..11269a7e 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -25,27 +25,21 @@ class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"): class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): """Action which adds a rule to a router's ACL.""" - target_router: str - position: int - permission: Literal[1, 2] - source_wildcard_id: int - source_port: str - dst_ip: str - dst_wildcard: int - dst_port: int + config: "RouterACLAddRuleAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema): """Configuration Schema for RouterACLAddRuleAction.""" target_router: str + permission: str + protocol_name: str position: int - permission: Literal[1, 2] src_ip: str src_wildcard: int source_port: str dst_ip: str dst_wildcard: int - dst_port: int + dst_port: str @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: @@ -71,11 +65,13 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_rule"): """Action which removes a rule from a router's ACL.""" + config: "RouterACLRemoveRuleAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for RouterACLRemoveRuleAction.""" target_router: str - position: str + position: int @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -86,33 +82,42 @@ class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_ru class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_rule"): """Action which adds a rule to a firewall port's ACL.""" - max_acl_rules: int - num_ips: int - num_ports: int - num_protocols: int - num_permissions: int = 3 - permission: str - target_firewall_nodename: str - src_ip: str - dst_ip: str - dst_wildcard: str - src_port: Union[int| None] - dst_port: Union[int | None] + config: "FirewallACLAddRuleAction.ConfigSchema" + + # max_acl_rules: int + # num_ips: int + # num_ports: int + # num_protocols: int + # num_permissions: int = 3 + # permission: str + # target_firewall_nodename: str + # src_ip: str + # dst_ip: str + # dst_wildcard: str + # src_port: Union[int| None] + # dst_port: Union[int | None] class ConfigSchema(ACLAbstractAction.ConfigSchema): """Configuration schema for FirewallACLAddRuleAction.""" - max_acl_rules: int - num_ips: int - num_ports: int - num_protocols: int - num_permissions: int = 3 - permission: str target_firewall_nodename: str + firewall_port_name: str + firewall_port_direction: str + position: int + permission: str src_ip: str - dst_ip: str - dst_wildcard: str - src_port: Union[int| None] + dest_ip: str + src_port: str + dst_port: str + protocol_name: str + source_wildcard_id: int + dest_wildcard_id: int + + # max_acl_rules: int + # num_ips: int + # num_ports: int + # num_protocols: int + # num_permissions: int = 3 @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: @@ -136,10 +141,10 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r config.permission, config.protocol_name, config.src_ip, - config.src_wildcard, + config.source_wildcard_id, config.src_port, - config.dst_ip, - config.dst_wildcard, + config.dest_ip, + config.dest_wildcard_id, config.dst_port, config.position, ] @@ -148,6 +153,8 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r class FirewallACLRemoveRuleAction(AbstractAction, identifier="firewall_acl_remove_rule"): """Action which removes a rule from a firewall port's ACL.""" + config:"FirewallACLRemoveRuleAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for FirewallACLRemoveRuleAction.""" diff --git a/src/primaite/game/agent/actions/application.py b/src/primaite/game/agent/actions/application.py index 942ebe90..f515a8ec 100644 --- a/src/primaite/game/agent/actions/application.py +++ b/src/primaite/game/agent/actions/application.py @@ -22,13 +22,14 @@ class NodeApplicationAbstractAction(AbstractAction, identifier="node_application inherit from this base class. """ + config: "NodeApplicationAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Base Configuration schema for Node Application actions.""" node_name: str application_name: str - - verb: ClassVar[str] + verb: ClassVar[str] @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -41,14 +42,14 @@ class NodeApplicationAbstractAction(AbstractAction, identifier="node_application config.node_name, "application", config.application_name, - cls.model_fields["verb"].default, + config.verb, ] class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="node_application_execute"): """Action which executes an application.""" - verb: str = "execute" + config: "NodeApplicationExecuteAction.ConfigSchema" class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): """Configuration schema for NodeApplicationExecuteAction.""" @@ -59,7 +60,7 @@ class NodeApplicationExecuteAction(NodeApplicationAbstractAction, identifier="no class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_application_scan"): """Action which scans an application.""" - verb: str = "scan" + config: "NodeApplicationScanAction.ConfigSchema" class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): """Configuration schema for NodeApplicationScanAction.""" @@ -70,7 +71,7 @@ class NodeApplicationScanAction(NodeApplicationAbstractAction, identifier="node_ class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node_application_close"): """Action which closes an application.""" - verb: str = "close" + config: "NodeApplicationCloseAction.ConfigSchema" class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): """Configuration schema for NodeApplicationCloseAction.""" @@ -81,7 +82,7 @@ class NodeApplicationCloseAction(NodeApplicationAbstractAction, identifier="node class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_application_fix"): """Action which fixes an application.""" - verb: str = "fix" + config: "NodeApplicationFixAction.ConfigSchema" class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): """Configuration schema for NodeApplicationFixAction.""" @@ -92,7 +93,7 @@ class NodeApplicationFixAction(NodeApplicationAbstractAction, identifier="node_a class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="node_application_install"): """Action which installs an application.""" - verb: str = "install" + config: "NodeApplicationInstallAction.ConfigSchema" class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): """Configuration schema for NodeApplicationInstallAction.""" @@ -110,7 +111,7 @@ class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="no config.node_name, "software_manager", "application", - cls.model_fields["verb"].default, + config.verb, config.application_name, ] @@ -118,7 +119,7 @@ class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="no class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="node_application_remove"): """Action which removes/uninstalls an application.""" - verb: str = "uninstall" + config: "NodeApplicationRemoveAction.ConfigSchema" class ConfigSchema(NodeApplicationAbstractAction.ConfigSchema): """Configuration schema for NodeApplicationRemoveAction.""" @@ -136,6 +137,6 @@ class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="nod config.node_name, "software_manager", "application", - cls.model_fields["verb"].default, + config.verb, config.application_name, ] diff --git a/src/primaite/game/agent/actions/config.py b/src/primaite/game/agent/actions/config.py index dc7e98b9..da9f77e6 100644 --- a/src/primaite/game/agent/actions/config.py +++ b/src/primaite/game/agent/actions/config.py @@ -22,6 +22,8 @@ __all__ = ( class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_ransomware_configure"): """Action which sets config parameters for a ransomware script on a node.""" + config: "ConfigureRansomwareScriptAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for ConfigureRansomwareScriptAction.""" @@ -36,16 +38,18 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_rans """Return the action formatted as a request that can be ingested by the simulation.""" if config.node_name is None: return ["do_nothing"] - ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config.model_config] class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"): """Action which sets config parameters for a DoS bot on a node.""" - class _Opts(BaseModel): + config: "ConfigureDoSBotAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): """Schema for options that can be passed to this action.""" + node_name: str model_config = ConfigDict(extra="forbid") target_ip_address: Optional[str] = None target_port: Optional[str] = None @@ -58,18 +62,19 @@ class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"): def __init__(self, manager: "ActionManager", **kwargs) -> None: super().__init__(manager=manager) - def form_request(self, node_id: int, config: Dict) -> RequestFormat: + def form_request(self, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: + if config.node_name is None: return ["do_nothing"] - self._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", node_name, "application", "DoSBot", "configure", config] + self.ConfigSchema.model_validate(config) # check that options adhere to schema + return ["network", "node", config.node_name, "application", "DoSBot", "configure", config] class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): """Action which configures a C2 Beacon based on the parameters given.""" + config: "ConfigureC2BeaconAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for ConfigureC2BeaconAction.""" @@ -79,14 +84,6 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): masquerade_protocol: str = Field(default="TCP") masquerade_port: str = Field(default="HTTP") - class _Opts(BaseModel): - """Schema for options that can be passed to this action.""" - - c2_server_ip_address: str - keep_alive_frequency: int = Field(default=5, ge=1) - masquerade_protocol: str = Field(default="TCP") - masquerade_port: str = Field(default="HTTP") - @field_validator( "c2_server_ip_address", "keep_alive_frequency", @@ -106,21 +103,23 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): """Return the action formatted as a request that can be ingested by the simulation.""" if config.node_name is None: return ["do_nothing"] - configuration = ConfigureC2BeaconAction._Opts( + configuration = ConfigureC2BeaconAction.ConfigSchema( c2_server_ip_address=config.c2_server_ip_address, keep_alive_frequency=config.keep_alive_frequency, masquerade_port=config.masquerade_port, masquerade_protocol=config.masquerade_protocol, ) - ConfigureC2BeaconAction._Opts.model_validate(configuration) # check that options adhere to schema + ConfigureC2BeaconAction.ConfigSchema.model_validate(configuration) # check that options adhere to schema - return ["network", "node", config.node_name, "application", "C2Beacon", "configure", config.__dict__] + return ["network", "node", config.node_name, "application", "C2Beacon", "configure", configuration] class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_command"): """Action which sends a terminal command to a remote node via SSH.""" + config: "NodeSendRemoteCommandAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for NodeSendRemoteCommandAction.""" @@ -146,37 +145,37 @@ class NodeSendRemoteCommandAction(AbstractAction, identifier="node_send_remote_c class TerminalC2ServerAction(AbstractAction, identifier="c2_server_terminal_command"): """Action which causes the C2 Server to send a command to the C2 Beacon to execute the terminal command passed.""" - class _Opts(BaseModel): + config: "TerminalC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): """Schema for options that can be passed to this action.""" + node_name: str commands: Union[List[RequestFormat], RequestFormat] ip_address: Optional[str] username: Optional[str] password: Optional[str] - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, node_id: int, commands: List, ip_address: Optional[str], account: dict) -> RequestFormat: + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: + if config.node_name is None: return ["do_nothing"] command_model = { - "commands": commands, - "ip_address": ip_address, - "username": account["username"], - "password": account["password"], + "commands": config.commands, + "ip_address": config.ip_address, + "username": config.username, + "password": config.password, } - - TerminalC2ServerAction._Opts.model_validate(command_model) - return ["network", "node", node_name, "application", "C2Server", "terminal_command", command_model] + return ["network", "node", config.node_name, "application", "C2Server", "terminal_command", command_model] class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ransomware_launch"): """Action which causes the C2 Server to send a command to the C2 Beacon to launch the RansomwareScript.""" + config: "RansomwareLaunchC2ServerAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for RansomwareLaunchC2ServerAction.""" @@ -194,9 +193,12 @@ class RansomwareLaunchC2ServerAction(AbstractAction, identifier="c2_server_ranso class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfiltrate"): """Action which exfiltrates a target file from a certain node onto the C2 beacon and then the C2 Server.""" - class _Opts(BaseModel): + config: "ExfiltrationC2ServerAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): """Schema for options that can be passed to this action.""" + node_name: str username: Optional[str] password: Optional[str] target_ip_address: str @@ -204,40 +206,30 @@ class ExfiltrationC2ServerAction(AbstractAction, identifier="c2_server_data_exfi target_folder_name: str exfiltration_folder_name: Optional[str] - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - + @classmethod def form_request( - self, - node_id: int, - account: dict, - target_ip_address: str, - target_file_name: str, - target_folder_name: str, - exfiltration_folder_name: Optional[str], + cls, + config: ConfigSchema ) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" - node_name = self.manager.get_node_name_by_idx(node_id) - if node_name is None: + if config.node_name is None: return ["do_nothing"] command_model = { - "target_file_name": target_file_name, - "target_folder_name": target_folder_name, - "exfiltration_folder_name": exfiltration_folder_name, - "target_ip_address": target_ip_address, - "username": account["username"], - "password": account["password"], + "target_file_name": config.target_file_name, + "target_folder_name": config.target_folder_name, + "exfiltration_folder_name": config.exfiltration_folder_name, + "target_ip_address": config.target_ip_address, + "username": config.username, + "password": config.password, } - ExfiltrationC2ServerAction._Opts.model_validate(command_model) - return ["network", "node", node_name, "application", "C2Server", "exfiltrate", command_model] + return ["network", "node", config.node_name, "application", "C2Server", "exfiltrate", command_model] class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_database_client"): """Action which sets config parameters for a database client on a node.""" - node_name: str - model_config: ConfigDict = ConfigDict(extra="forbid") + config: "ConfigureDatabaseClientAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema): """Schema for options that can be passed to this action.""" @@ -245,10 +237,8 @@ class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_databa node_name: str model_config = ConfigDict(extra="forbid") - def __init__(self, manager: "ActionManager", **kwargs) -> None: - super().__init__(manager=manager) - - def form_request(self, config: ConfigSchema) -> RequestFormat: + @classmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" if config.node_name is None: return ["do_nothing"] diff --git a/src/primaite/game/agent/actions/file.py b/src/primaite/game/agent/actions/file.py index 5d12b27a..b5e47c8a 100644 --- a/src/primaite/game/agent/actions/file.py +++ b/src/primaite/game/agent/actions/file.py @@ -23,14 +23,15 @@ class NodeFileAbstractAction(AbstractAction, identifier="node_file_abstract_acti only three parameters can inherit from this base class. """ + config: "NodeFileAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration Schema for NodeFileAbstractAction.""" node_name: str folder_name: str file_name: str - - verb: ClassVar[str] + verb: ClassVar[str] @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -46,15 +47,14 @@ class NodeFileAbstractAction(AbstractAction, identifier="node_file_abstract_acti config.folder_name, "file", config.file_name, - cls.model_fields["verb"].default, + config.verb, ] class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create"): """Action which creates a new file in a given folder.""" - verb: str = "create" - force: bool = False + config: "NodeFileCreateAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration schema for NodeFileCreateAction.""" @@ -72,18 +72,18 @@ class NodeFileCreateAction(NodeFileAbstractAction, identifier="node_file_create" "node", config.node_name, "file_system", - cls.model_fields["verb"].default, + config.verb, "file", config.folder_name, config.file_name, - cls.model_fields["force"].default, + config.verb, ] class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"): """Action which scans a file.""" - verb: str = "scan" + config: "NodeFileScanAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration schema for NodeFileScanAction.""" @@ -94,7 +94,7 @@ class NodeFileScanAction(NodeFileAbstractAction, identifier="node_file_scan"): class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete"): """Action which deletes a file.""" - verb: str = "delete" + config: "NodeFileDeleteAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration schema for NodeFileDeleteAction.""" @@ -111,7 +111,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete" "node", config.node_name, "file_system", - cls.model_fields["verb"].default, + config.verb, "file", config.folder_name, config.file_name, @@ -121,7 +121,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction, identifier="node_file_delete" class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restore"): """Action which restores a file.""" - verb: str = "restore" + config: "NodeFileRestoreAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration schema for NodeFileRestoreAction.""" @@ -132,7 +132,7 @@ class NodeFileRestoreAction(NodeFileAbstractAction, identifier="node_file_restor class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrupt"): """Action which corrupts a file.""" - verb: str = "corrupt" + config: "NodeFileCorruptAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration schema for NodeFileCorruptAction.""" @@ -143,7 +143,7 @@ class NodeFileCorruptAction(NodeFileAbstractAction, identifier="node_file_corrup class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access"): """Action which increases a file's access count.""" - verb: str = "access" + config: "NodeFileAccessAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration schema for NodeFileAccessAction.""" @@ -160,7 +160,7 @@ class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access" "node", config.node_name, "file_system", - cls.model_fields["verb"].default, + config.verb, config.folder_name, config.file_name, ] @@ -169,7 +169,7 @@ class NodeFileAccessAction(NodeFileAbstractAction, identifier="node_file_access" class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_checkhash"): """Action which checks the hash of a file.""" - verb: str = "checkhash" + config: "NodeFileCheckhashAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration schema for NodeFileCheckhashAction.""" @@ -180,7 +180,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_chec class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"): """Action which repairs a file.""" - verb: str = "repair" + config: "NodeFileRepairAction.ConfigSchema" class ConfigSchema(NodeFileAbstractAction.ConfigSchema): """Configuration Schema for NodeFileRepairAction.""" diff --git a/src/primaite/game/agent/actions/folder.py b/src/primaite/game/agent/actions/folder.py index e430efb7..a27ca89b 100644 --- a/src/primaite/game/agent/actions/folder.py +++ b/src/primaite/game/agent/actions/folder.py @@ -21,13 +21,14 @@ class NodeFolderAbstractAction(AbstractAction, identifier="node_folder_abstract" this base class. """ + config: "NodeFolderAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Base configuration schema for NodeFolder actions.""" node_name: str folder_name: str - - verb: ClassVar[str] + verb: ClassVar[str] @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -41,14 +42,14 @@ class NodeFolderAbstractAction(AbstractAction, identifier="node_folder_abstract" "file_system", "folder", config.folder_name, - cls.model_fields["verb"].default, + config.verb, ] class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_scan"): """Action which scans a folder.""" - verb: str = "scan" + config: "NodeFolderScanAction.ConfigSchema" class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): """Configuration schema for NodeFolderScanAction.""" @@ -59,7 +60,7 @@ class NodeFolderScanAction(NodeFolderAbstractAction, identifier="node_folder_sca class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folder_checkhash"): """Action which checks the hash of a folder.""" - verb: str = "checkhash" + config: "NodeFolderCheckhashAction.ConfigSchema" class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): """Configuration schema for NodeFolderCheckhashAction.""" @@ -70,7 +71,7 @@ class NodeFolderCheckhashAction(NodeFolderAbstractAction, identifier="node_folde class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_repair"): """Action which repairs a folder.""" - verb: str = "repair" + config: "NodeFolderRepairAction.ConfigSchema" class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): """Configuration schema for NodeFolderRepairAction.""" @@ -81,7 +82,7 @@ class NodeFolderRepairAction(NodeFolderAbstractAction, identifier="node_folder_r class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_restore"): """Action which restores a folder.""" - verb: str = "restore" + config: "NodeFolderRestoreAction.ConfigSchema" class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): """Configuration schema for NodeFolderRestoreAction.""" @@ -92,7 +93,7 @@ class NodeFolderRestoreAction(NodeFolderAbstractAction, identifier="node_folder_ class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_create"): """Action which creates a new folder.""" - verb: str = "create" + config: "NodeFolderCreateAction.ConfigSchema" class ConfigSchema(NodeFolderAbstractAction.ConfigSchema): """Configuration schema for NodeFolderCreateAction.""" @@ -109,7 +110,7 @@ class NodeFolderCreateAction(NodeFolderAbstractAction, identifier="node_folder_c "node", config.node_name, "file_system", - cls.model_fields["verb"].default, + config.verb, "folder", config.folder_name, ] diff --git a/src/primaite/game/agent/actions/host_nic.py b/src/primaite/game/agent/actions/host_nic.py index 1ad2e52f..6df241bc 100644 --- a/src/primaite/game/agent/actions/host_nic.py +++ b/src/primaite/game/agent/actions/host_nic.py @@ -1,4 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from typing import ClassVar from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat @@ -13,14 +14,14 @@ class HostNICAbstractAction(AbstractAction, identifier="host_nic_abstract"): class. """ - node_name: str - nic_num: str + config: "HostNICAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema): """Base Configuration schema for HostNIC actions.""" node_name: str - nic_num: str + nic_num: int + verb: ClassVar[str] @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -33,14 +34,14 @@ class HostNICAbstractAction(AbstractAction, identifier="host_nic_abstract"): config.node_name, "network_interface", config.nic_num, - cls.model_fields["verb"].default, + config.verb, ] class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"): """Action which enables a NIC.""" - verb: str = "enable" + config: "HostNICEnableAction.ConfigSchema" class ConfigSchema(HostNICAbstractAction.ConfigSchema): """Configuration schema for HostNICEnableAction.""" @@ -51,7 +52,7 @@ class HostNICEnableAction(HostNICAbstractAction, identifier="host_nic_enable"): class HostNICDisableAction(HostNICAbstractAction, identifier="host_nic_disable"): """Action which disables a NIC.""" - verb: str = "disable" + config: "HostNICDisableAction.ConfigSchema" class ConfigSchema(HostNICAbstractAction.ConfigSchema): """Configuration schema for HostNICDisableAction.""" diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py index 3795d21d..9ef94069 100644 --- a/src/primaite/game/agent/actions/manager.py +++ b/src/primaite/game/agent/actions/manager.py @@ -31,7 +31,8 @@ class DoNothingAction(AbstractAction, identifier="do_nothing"): class ConfigSchema(AbstractAction.ConfigSchema): """Configuration Schema for DoNothingAction.""" - type: Literal["do_nothing"] = "do_nothing" + # type: Literal["do_nothing"] = "do_nothing" + type: str = "do_nothing" @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -126,13 +127,13 @@ class ActionManager: def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" act_obj = self.actions[action_identifier].from_config(config=action_options) - return act_obj.form_request(config=act_obj.ConfigSchema) + return act_obj.form_request(config=act_obj.config) @property def space(self) -> spaces.Space: """Return the gymnasium action space for this agent.""" return spaces.Discrete(len(self.action_map)) - + @classmethod def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": """ diff --git a/src/primaite/game/agent/actions/network.py b/src/primaite/game/agent/actions/network.py index af3793a2..346da9b7 100644 --- a/src/primaite/game/agent/actions/network.py +++ b/src/primaite/game/agent/actions/network.py @@ -11,13 +11,14 @@ __all__ = ("NetworkPortEnableAction", "NetworkPortDisableAction") class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstract"): """Base class for Network port actions.""" + config: "NetworkPortAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Base configuration schema for NetworkPort actions.""" target_nodename: str - port_id: str - - verb: ClassVar[str] + port_id: int + verb: ClassVar[str] @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -30,16 +31,16 @@ class NetworkPortAbstractAction(AbstractAction, identifier="network_port_abstrac config.target_nodename, "network_interface", config.port_id, - cls.model_fields["verb"].default, + config.verb, ] class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_port_enable"): """Action which enables are port on a router or a firewall.""" - verb: str = "enable" + config: "NetworkPortEnableAction.ConfigSchema" - class ConfigSchema(AbstractAction.ConfigSchema): + class ConfigSchema(NetworkPortAbstractAction.ConfigSchema): """Configuration schema for NetworkPortEnableAction.""" verb: str = "enable" @@ -48,9 +49,9 @@ class NetworkPortEnableAction(NetworkPortAbstractAction, identifier="network_por class NetworkPortDisableAction(NetworkPortAbstractAction, identifier="network_port_disable"): """Action which disables are port on a router or a firewall.""" - verb: str = "disable" + config: "NetworkPortDisableAction.ConfigSchema" - class ConfigSchema(AbstractAction.ConfigSchema): + class ConfigSchema(NetworkPortAbstractAction.ConfigSchema): """Configuration schema for NetworkPortDisableAction.""" verb: str = "disable" diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index a69a8a5f..3c70d495 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -23,22 +23,25 @@ class NodeAbstractAction(AbstractAction, identifier="node_abstract"): Any action which applies to a node and uses node_name as its only parameter can inherit from this base class. """ + config: "NodeAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Base Configuration schema for Node actions.""" node_name: str - - verb: ClassVar[str] + verb: ClassVar[str] @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", config.node_name, cls.verb] + return ["network", "node", config.node_name, cls.config.verb] class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"): """Action which scans a node's OS.""" + config: "NodeOSScanAction.ConfigSchema" + class ConfigSchema(NodeAbstractAction.ConfigSchema): """Configuration schema for NodeOSScanAction.""" @@ -48,6 +51,8 @@ class NodeOSScanAction(NodeAbstractAction, identifier="node_os_scan"): class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"): """Action which shuts down a node.""" + config: "NodeShutdownAction.ConfigSchema" + class ConfigSchema(NodeAbstractAction.ConfigSchema): """Configuration schema for NodeShutdownAction.""" @@ -57,6 +62,8 @@ class NodeShutdownAction(NodeAbstractAction, identifier="node_shutdown"): class NodeStartupAction(NodeAbstractAction, identifier="node_startup"): """Action which starts up a node.""" + config: "NodeStartupAction.ConfigSchema" + class ConfigSchema(NodeAbstractAction.ConfigSchema): """Configuration schema for NodeStartupAction.""" @@ -66,6 +73,8 @@ class NodeStartupAction(NodeAbstractAction, identifier="node_startup"): class NodeResetAction(NodeAbstractAction, identifier="node_reset"): """Action which resets a node.""" + config: "NodeResetAction.ConfigSchema" + class ConfigSchema(NodeAbstractAction.ConfigSchema): """Configuration schema for NodeResetAction.""" @@ -75,22 +84,28 @@ class NodeResetAction(NodeAbstractAction, identifier="node_reset"): class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"): """Base class for NodeNMAP actions.""" + config: "NodeNMAPAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Base Configuration Schema for NodeNMAP actions.""" target_ip_address: Union[str, List[str]] - show: bool = False + show: bool = False node_name: str @classmethod @abstractmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: + # NMAP action requests don't share a common format for their requests + # This is just a placeholder to ensure the method is defined. pass class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"): """Action which performs an NMAP ping scan.""" + config: "NodeNMAPPingScanAction.ConfigSchema" + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): """Configuration schema for NodeNMAPPingScanAction.""" @@ -113,6 +128,8 @@ class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"): """Action which performs an NMAP port scan.""" + config: "NodeNMAPPortScanAction.ConfigSchema" + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): """Configuration Schema for NodeNMAPPortScanAction.""" @@ -146,6 +163,8 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"): """Action which performs an NMAP network service recon (ping scan followed by port scan).""" + config: "NodeNetworkServiceReconAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for NodeNetworkServiceReconAction.""" diff --git a/src/primaite/game/agent/actions/service.py b/src/primaite/game/agent/actions/service.py index dbdd57d3..7ccffb0a 100644 --- a/src/primaite/game/agent/actions/service.py +++ b/src/primaite/game/agent/actions/service.py @@ -23,22 +23,23 @@ class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstrac Any actions which use node_name and service_name can inherit from this class. """ + config: "NodeServiceAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): node_name: str service_name: str - - verb: ClassVar[str] + verb: ClassVar[str] @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - return ["network", "node", config.node_name, "service", config.service_name, cls.model_fields["verb"].default] + return ["network", "node", config.node_name, "service", config.service_name, config.verb] class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_scan"): """Action which scans a service.""" - verb: str = "scan" + config: "NodeServiceScanAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceScanAction.""" @@ -49,7 +50,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_ class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_stop"): """Action which stops a service.""" - verb: str = "stop" + config: "NodeServiceStopAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceStopAction.""" @@ -60,7 +61,7 @@ class NodeServiceStopAction(NodeServiceAbstractAction, identifier="node_service_ class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service_start"): """Action which starts a service.""" - verb: str = "start" + config: "NodeServiceStartAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceStartAction.""" @@ -71,7 +72,7 @@ class NodeServiceStartAction(NodeServiceAbstractAction, identifier="node_service class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service_pause"): """Action which pauses a service.""" - verb: str = "pause" + config: "NodeServicePauseAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServicePauseAction.""" @@ -82,7 +83,7 @@ class NodeServicePauseAction(NodeServiceAbstractAction, identifier="node_service class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_service_resume"): """Action which resumes a service.""" - verb: str = "resume" + config: "NodeServiceResumeAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceResumeAction.""" @@ -93,7 +94,7 @@ class NodeServiceResumeAction(NodeServiceAbstractAction, identifier="node_servic class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_service_restart"): """Action which restarts a service.""" - verb: str = "restart" + config: "NodeServiceRestartAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceRestartAction.""" @@ -104,7 +105,7 @@ class NodeServiceRestartAction(NodeServiceAbstractAction, identifier="node_servi class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_service_disable"): """Action which disables a service.""" - verb: str = "disable" + config: "NodeServiceDisableAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceDisableAction.""" @@ -115,7 +116,7 @@ class NodeServiceDisableAction(NodeServiceAbstractAction, identifier="node_servi class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_service_enable"): """Action which enables a service.""" - verb: str = "enable" + config: "NodeServiceEnableAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceEnableAction.""" @@ -126,7 +127,7 @@ class NodeServiceEnableAction(NodeServiceAbstractAction, identifier="node_servic class NodeServiceFixAction(NodeServiceAbstractAction, identifier="node_service_fix"): """Action which fixes a service.""" - verb: str = "fix" + config: "NodeServiceFixAction.ConfigSchema" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): """Configuration Schema for NodeServiceFixAction.""" diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py index dcae8b47..a0805a49 100644 --- a/src/primaite/game/agent/actions/session.py +++ b/src/primaite/game/agent/actions/session.py @@ -14,6 +14,8 @@ __all__ = ( class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstract"): """Base class for NodeSession actions.""" + config: "NodeSessionAbstractAction.ConfigSchema" + class ConfigSchema(AbstractAction.ConfigSchema): """Base configuration schema for NodeSessionAbstractActions.""" @@ -34,8 +36,7 @@ class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstrac class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_session_remote_login"): """Action which performs a remote session login.""" - username: str - password: str + config: "NodeSessionsRemoteLoginAction.ConfigSchema" class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLoginAction.""" @@ -64,6 +65,8 @@ class NodeSessionsRemoteLoginAction(NodeSessionAbstractAction, identifier="node_ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node_session_remote_logoff"): """Action which performs a remote session logout.""" + config: "NodeSessionsRemoteLogoutAction.ConfigSchema" + class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLogoutAction.""" @@ -80,9 +83,7 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"): """Action which changes the password for a user.""" - username: str - current_password: str - new_password: str + config: "NodeAccountChangePasswordAction.ConfigSchema" class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeAccountsChangePasswordAction.""" @@ -103,5 +104,5 @@ class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="nod "change_password", config.username, config.current_password, - cls.new_password, + config.new_password, ] diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index a2b75be5..4c3b5000 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -114,12 +114,12 @@ agents: position: 1 permission: PERMIT src_ip: 192.168.0.10 - dst_ip: ALL + dest_ip: ALL src_port: ALL dst_port: ALL protocol_name: ALL - src_wildcard: 0 - dst_wildcard: 0 + source_wildcard_id: 0 + dest_wildcard_id: 0 2: action: firewall_acl_remove_rule options: @@ -158,7 +158,7 @@ agents: position: 1 permission: DENY src_ip: 192.168.10.10 # dmz_server - dst_ip: 192.168.0.10 # client_1 + dest_ip: 192.168.0.10 # client_1 src_port: HTTP dst_port: HTTP protocol_name: UDP @@ -180,7 +180,7 @@ agents: position: 2 permission: DENY src_ip: 192.168.10.10 # dmz_server - dst_ip: 192.168.0.10 # client_1 + dest_ip: 192.168.0.10 # client_1 src_port: HTTP dst_port: HTTP protocol_name: TCP @@ -202,7 +202,7 @@ agents: position: 10 permission: DENY src_ip: 192.168.20.10 # external_computer - dst_ip: 192.168.10.10 # dmz + dest_ip: 192.168.10.10 # dmz src_port: POSTGRES_SERVER dst_port: POSTGRES_SERVER protocol_name: ICMP @@ -224,7 +224,7 @@ agents: position: 1 permission: DENY src_ip: 192.168.20.10 # external_computer - dst_ip: 192.168.0.10 # client_1 + dest_ip: 192.168.0.10 # client_1 src_port: NONE dst_port: NONE protocol_name: none diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index c4350e1f..a31f325a 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -56,7 +56,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy assert svc.health_state_visible == SoftwareHealthState.UNUSED # 2: Scan and check that the visible state is now correct - action = ("node_service_scan", {"node_name": "server_1", "service_name": "DNSServer"}) + action = ("node_service_scan", {"type":"node_service_scan" ,"node_name": "server_1", "service_name": "DNSServer"}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.GOOD @@ -67,7 +67,7 @@ def test_node_service_scan_integration(game_and_agent: Tuple[PrimaiteGame, Proxy assert svc.health_state_visible == SoftwareHealthState.GOOD # 4: Scan and check that the visible state is now correct - action = ("node_service_scan", {"node_name": "server_1", "service_name": "DNSServer"}) + action = ("node_service_scan", {"type":"node_service_scan", "node_name": "server_1", "service_name": "DNSServer"}) agent.store_action(action) game.step() assert svc.health_state_actual == SoftwareHealthState.COMPROMISED