diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 02134650..64cbe0cf 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -175,10 +175,6 @@ class NodeApplicationInstallAction(AbstractAction): class ConfigureDatabaseClientAction(AbstractAction): """Action which sets config parameters for a database client on a node.""" - model_config = ConfigDict(extra="forbid") - server_ip_address: Optional[str] = None - server_password: Optional[str] = None - class _Opts(BaseModel): """Schema for options that can be passed to this action.""" @@ -204,7 +200,6 @@ class ConfigureRansomwareScriptAction(AbstractAction): class _Opts(BaseModel, AbstractAction.ConfigSchema): """Schema for options that can be passed to this option.""" - node_name: str model_config = ConfigDict(extra="forbid") server_ip_address: Optional[str] = None server_password: Optional[str] = None diff --git a/src/primaite/game/agent/actions/abstract.py b/src/primaite/game/agent/actions/abstract.py index 2ed168d9..c18f0dbc 100644 --- a/src/primaite/game/agent/actions/abstract.py +++ b/src/primaite/game/agent/actions/abstract.py @@ -43,6 +43,5 @@ class AbstractAction(BaseModel): def from_config(cls, config: Dict) -> "AbstractAction": """Create an action component from a config dictionary.""" for attribute, value in config.items(): - if not hasattr(cls.ConfigSchema, attribute): - setattr(cls.ConfigSchema, attribute, value) + setattr(cls.ConfigSchema, attribute, value) return cls diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index d6d5f4b4..3beface9 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -1,5 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import List, Literal +from typing import List, Literal, Union from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat @@ -92,6 +92,12 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r num_protocols: int num_permissions: int = 3 permission: str + target_firewall_nodename: str + src_ip: str + dst_ip: str + dst_wildcard: str + src_port: Union[int| None] + dst_port: Union[int | None] class ConfigSchema(ACLAbstractAction.ConfigSchema): """Configuration schema for FirewallACLAddRuleAction.""" @@ -102,29 +108,22 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r num_protocols: int num_permissions: int = 3 permission: str + target_firewall_nodename: str + src_ip: str + dst_ip: str + dst_wildcard: str + src_port: Union[int| None] @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if config.protocol_id == 0: + if config.protocol_name == None: return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - if config.source_ip_id == 0: + if config.src_ip == 0: return ["do_nothing"] # invalid formulation - elif config.source_ip_id == 1: - src_ip = "ALL" - else: - # src_ip = self.manager.get_ip_address_by_idx(source_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - pass + if config.src_port == 0: + return ["do_nothing"] # invalid configuration. - if config.source_port_id == 0: - return ["do_nothing"] # invalid formulation - elif config.source_port_id == 1: - src_port = "ALL" - else: - # src_port = self.manager.get_port_by_idx(source_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - pass return [ "network", @@ -136,9 +135,9 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r "add_rule", config.permission, config.protocol_name, - str(src_ip), + config.src_ip, config.src_wildcard, - src_port, + config.src_port, config.dst_ip, config.dst_wildcard, config.dst_port, diff --git a/src/primaite/game/agent/actions/config.py b/src/primaite/game/agent/actions/config.py index d7b436d7..dc7e98b9 100644 --- a/src/primaite/game/agent/actions/config.py +++ b/src/primaite/game/agent/actions/config.py @@ -37,7 +37,7 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_rans if config.node_name is None: return ["do_nothing"] ConfigureRansomwareScriptAction._Opts.model_validate(config) # check that options adhere to schema - return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config] + return ["network", "node", config.node_name, "application", "RansomwareScript", "configure", config.model_config] class ConfigureDoSBotAction(AbstractAction, identifier="configure_dos_bot"): diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index 88b09a29..a2b75be5 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -112,14 +112,14 @@ agents: firewall_port_name: internal firewall_port_direction: inbound position: 1 - permission: 1 - source_ip_id: 2 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 - source_wildcard_id: 0 - dest_wildcard_id: 0 + permission: PERMIT + src_ip: 192.168.0.10 + dst_ip: ALL + src_port: ALL + dst_port: ALL + protocol_name: ALL + src_wildcard: 0 + dst_wildcard: 0 2: action: firewall_acl_remove_rule options: @@ -134,12 +134,12 @@ agents: firewall_port_name: internal firewall_port_direction: outbound position: 1 - permission: 2 - source_ip_id: 2 # client 1 - dest_ip_id: 1 # ALL - source_port_id: 2 - dest_port_id: 3 - protocol_id: 2 + permission: DENY + src_ip: 192.168.0.10 # client 1 + dest_ip: ALL # ALL + src_port: ARP + dst_port: DNS + protocol_name: ICMP source_wildcard_id: 0 dest_wildcard_id: 0 4: @@ -156,12 +156,12 @@ agents: firewall_port_name: dmz firewall_port_direction: inbound position: 1 - permission: 2 - source_ip_id: 3 # dmz_server - dest_ip_id: 2 # client_1 - source_port_id: 4 - dest_port_id: 4 - protocol_id: 4 + permission: DENY + src_ip: 192.168.10.10 # dmz_server + dst_ip: 192.168.0.10 # client_1 + src_port: HTTP + dst_port: HTTP + protocol_name: UDP source_wildcard_id: 0 dest_wildcard_id: 0 6: @@ -178,12 +178,12 @@ agents: firewall_port_name: dmz firewall_port_direction: outbound position: 2 - permission: 2 - source_ip_id: 3 # dmz_server - dest_ip_id: 2 # client_1 - source_port_id: 4 - dest_port_id: 4 - protocol_id: 3 + permission: DENY + src_ip: 192.168.10.10 # dmz_server + dst_ip: 192.168.0.10 # client_1 + src_port: HTTP + dst_port: HTTP + protocol_name: TCP source_wildcard_id: 0 dest_wildcard_id: 0 8: @@ -200,12 +200,12 @@ agents: firewall_port_name: external firewall_port_direction: inbound position: 10 - permission: 2 - source_ip_id: 4 # external_computer - dest_ip_id: 3 # dmz - source_port_id: 5 - dest_port_id: 5 - protocol_id: 2 + permission: DENY + src_ip: 192.168.20.10 # external_computer + dst_ip: 192.168.10.10 # dmz + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + protocol_name: ICMP source_wildcard_id: 0 dest_wildcard_id: 0 10: @@ -222,12 +222,12 @@ agents: firewall_port_name: external firewall_port_direction: outbound position: 1 - permission: 2 - source_ip_id: 4 # external_computer - dest_ip_id: 2 # client_1 - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 + permission: DENY + src_ip: 192.168.20.10 # external_computer + dst_ip: 192.168.0.10 # client_1 + src_port: NONE + dst_port: NONE + protocol_name: none source_wildcard_id: 0 dest_wildcard_id: 0 12: diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index f380ba7d..c4350e1f 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -163,8 +163,9 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox }, ) agent.store_action(action) + print(agent.most_recent_action) game.step() - + print(agent.most_recent_action) # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 print(router.acl.show()) assert router.acl.num_rules == 6 @@ -653,9 +654,9 @@ def test_firewall_acl_add_remove_rule_integration(): assert firewall.external_outbound_acl.acl[1].action.name == "DENY" assert firewall.external_outbound_acl.acl[1].src_ip_address == IPv4Address("192.168.20.10") assert firewall.external_outbound_acl.acl[1].dst_ip_address == IPv4Address("192.168.0.10") - assert firewall.external_outbound_acl.acl[1].dst_port is None - assert firewall.external_outbound_acl.acl[1].src_port is None - assert firewall.external_outbound_acl.acl[1].protocol is None + assert firewall.external_outbound_acl.acl[1].dst_port == PORT_LOOKUP["NONE"] + assert firewall.external_outbound_acl.acl[1].src_port == PORT_LOOKUP["NONE"] + assert firewall.external_outbound_acl.acl[1].protocol == PROTOCOL_LOOKUP["NONE"] env.step(12) # Remove ACL rule from External Outbound assert firewall.external_outbound_acl.num_rules == 1