From ed020f005fde98504749e0060241a512dcc83ce4 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 13 Nov 2024 10:40:51 +0000 Subject: [PATCH] #2912 - Pre-commit updates ahead of first draft PR. --- .../how_to_guides/extensible_actions.rst | 2 +- src/primaite/game/agent/actions.py | 2 +- src/primaite/game/agent/actions/acl.py | 174 ++---------------- src/primaite/game/agent/actions/config.py | 1 - src/primaite/game/agent/actions/file.py | 2 +- src/primaite/game/agent/actions/manager.py | 3 +- src/primaite/game/agent/actions/node.py | 8 + src/primaite/game/agent/actions/session.py | 4 +- .../game_layer/test_actions.py | 1 + 9 files changed, 36 insertions(+), 161 deletions(-) diff --git a/docs/source/how_to_guides/extensible_actions.rst b/docs/source/how_to_guides/extensible_actions.rst index a6c12303..bd78c8e1 100644 --- a/docs/source/how_to_guides/extensible_actions.rst +++ b/docs/source/how_to_guides/extensible_actions.rst @@ -29,7 +29,7 @@ New actions to be used within PrimAITE require: class ConfigSchema(AbstractAction.ConfigSchema): target_application: str - The ConfigSchema is used when the class is called to form the action. + The ConfigSchema is used when the class is called to form the action, within the `form_request` method, detailed below. #. **Unique Identifier**: diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 885e0238..02134650 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -297,7 +297,7 @@ class NodeFolderAbstractAction(AbstractAction): class NodeFolderScanAction(NodeFolderAbstractAction): """Action which scans a folder.""" - def __init__(self, manager: "ActionManager", node_name: str, folder_name, **kwargs) -> None: + def __init__(self, manager: "ActionManager", node_name: str, folder_name: str, **kwargs) -> None: super().__init__(manager, node_name=node_name, folder_name=folder_name, **kwargs) self.verb: str = "scan" diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index 72a0b262..d6d5f4b4 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -1,9 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Dict, List, Literal +from typing import List, Literal -from pydantic import BaseModel, Field, field_validator, ValidationInfo - -from primaite.game.agent.actions.manager import AbstractAction, ActionManager +from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat __all__ = ( @@ -20,6 +18,9 @@ class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"): class ConfigSchema(AbstractAction.ConfigSchema): """Configuration Schema base for ACL abstract actions.""" + src_ip: str + protocol_name: str + class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): """Action which adds a rule to a router's ACL.""" @@ -27,13 +28,11 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): target_router: str position: int permission: Literal[1, 2] - src_ip: str source_wildcard_id: int source_port: str dst_ip: str dst_wildcard: int dst_port: int - protocol_name: str class ConfigSchema(AbstractAction.ConfigSchema): """Configuration Schema for RouterACLAddRuleAction.""" @@ -47,66 +46,10 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): dst_ip: str dst_wildcard: int dst_port: int - protocol_name: str - - class ACLRuleOptions(BaseModel): - """Validator for ACL_ADD_RULE options.""" - - target_router: str - """On which router to add the rule, must be specified.""" - position: int - """At what position to add the rule, must be specified.""" - permission: Literal[1, 2] - """Whether to allow or deny traffic, must be specified. 1 = PERMIT, 2 = DENY.""" - src_ip: str - """Rule source IP address. By default, all ip addresses.""" - source_wildcard_id: int = Field(default=0, ge=0) - """Rule source IP wildcard. By default, use the wildcard at index 0 from action manager.""" - source_port: int = Field(default=1, ge=1) - """Rule source port. By default, all source ports.""" - dst_ip_id: int = Field(default=1, ge=1) - """Rule destination IP address. By default, all ip addresses.""" - dst_wildcard: int = Field(default=0, ge=0) - """Rule destination IP wildcard. By default, use the wildcard at index 0 from action manager.""" - dst_port_id: int = Field(default=1, ge=1) - """Rule destination port. By default, all destination ports.""" - protocol_name: str = "ALL" - """Rule protocol. By default, all protocols.""" - - # @field_validator( - # "source_ip_id", - # "source_port_id", - # "source_wildcard_id", - # "dest_ip_id", - # "dest_port_id", - # "dest_wildcard_id", - # "protocol_name", - # mode="before", - # ) - # @classmethod - # def not_none(cls, v: str, info: ValidationInfo) -> int: - # """If None is passed, use the default value instead.""" - # if v is None: - # return cls.model_fields[info.field_name].default - # return v @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - # Validate incoming data. - # parsed_options = RouterACLAddRuleAction.ACLRuleOptions( - # target_router=config.target_router, - # position=config.position, - # permission=config.permission, - # src_ip=config.src_ip, - # source_wildcard_id=config.source_wildcard_id, - # dst_ip_id=config.dst_ip, - # dst_wildcard=config.dst_wildcard, - # source_port_id=config.source_port_id, - # dest_port=config.dst_port, - # protocol=config.protocol_name, - # ) - return [ "network", "node", @@ -118,7 +61,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): config.src_ip, config.src_wildcard, config.source_port, - str(config.dst_ip), + config.dst_ip, config.dst_wildcard, config.dst_port, config.position, @@ -160,63 +103,11 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r num_permissions: int = 3 permission: str - def __init__( - self, - manager: "ActionManager", - max_acl_rules: int, - num_ips: int, - num_ports: int, - num_protocols: int, - **kwargs, - ) -> None: - """Init method for FirewallACLAddRuleAction. - - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - :param num_ips: Number of IP addresses in the simulation. - :type num_ips: int - :param num_ports: Number of ports in the simulation. - :type num_ports: int - :param num_protocols: Number of protocols in the simulation. - :type num_protocols: int - """ - super().__init__(manager=manager) - num_permissions = 3 - self.shape: Dict[str, int] = { - "position": max_acl_rules, - "permission": num_permissions, - "source_ip_id": num_ips, - "dest_ip_id": num_ips, - "source_port_id": num_ports, - "dest_port_id": num_ports, - "protocol_id": num_protocols, - } - @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if config.permission == 0: - permission_str = "UNUSED" - return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - elif config.permission == 1: - permission_str = "PERMIT" - elif config.permission == 2: - permission_str = "DENY" - # else: - # _LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.") - if config.protocol_id == 0: return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS - - if config.protocol_id == 1: - protocol = "ALL" - else: - # protocol = self.manager.get_internet_protocol_by_idx(protocol_id - 2) - # subtract 2 to account for UNUSED=0 and ALL=1. - pass - if config.source_ip_id == 0: return ["do_nothing"] # invalid formulation elif config.source_ip_id == 1: @@ -235,26 +126,6 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r # subtract 2 to account for UNUSED=0, and ALL=1 pass - if config.dest_ip_id == 0: - return ["do_nothing"] # invalid formulation - elif config.dest_ip_id == 1: - dst_ip = "ALL" - else: - # dst_ip = self.manager.get_ip_address_by_idx(dest_ip_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - pass - - if config.dest_port_id == 0: - return ["do_nothing"] # invalid formulation - elif config.dest_port_id == 1: - dst_port = "ALL" - else: - # dst_port = self.manager.get_port_by_idx(dest_port_id - 2) - # subtract 2 to account for UNUSED=0, and ALL=1 - # src_wildcard = self.manager.get_wildcard_by_idx(source_wildcard_id) - # dst_wildcard = self.manager.get_wildcard_by_idx(dest_wildcard_id) - pass - return [ "network", "node", @@ -264,13 +135,13 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r "acl", "add_rule", config.permission, - protocol, + config.protocol_name, str(src_ip), config.src_wildcard, src_port, - str(dst_ip), + config.dst_ip, config.dst_wildcard, - dst_port, + config.dst_port, config.position, ] @@ -278,29 +149,24 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r class FirewallACLRemoveRuleAction(AbstractAction, identifier="firewall_acl_remove_rule"): """Action which removes a rule from a firewall port's ACL.""" - def __init__(self, manager: "ActionManager", max_acl_rules: int, **kwargs) -> None: - """Init method for RouterACLRemoveRuleAction. + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for FirewallACLRemoveRuleAction.""" - :param manager: Reference to the ActionManager which created this action. - :type manager: ActionManager - :param max_acl_rules: Maximum number of ACL rules that can be added to the router. - :type max_acl_rules: int - """ - super().__init__(manager=manager) - self.shape: Dict[str, int] = {"position": max_acl_rules} + target_firewall_nodename: str + firewall_port_name: str + firewall_port_direction: str + position: int @classmethod - def form_request( - cls, target_firewall_nodename: str, firewall_port_name: str, firewall_port_direction: str, position: int - ) -> List[str]: + def form_request(cls, config: ConfigSchema) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" return [ "network", "node", - target_firewall_nodename, - firewall_port_name, - firewall_port_direction, + config.target_firewall_nodename, + config.firewall_port_name, + config.firewall_port_direction, "acl", "remove_rule", - position, + config.position, ] diff --git a/src/primaite/game/agent/actions/config.py b/src/primaite/game/agent/actions/config.py index a4247e21..d7b436d7 100644 --- a/src/primaite/game/agent/actions/config.py +++ b/src/primaite/game/agent/actions/config.py @@ -250,7 +250,6 @@ class ConfigureDatabaseClientAction(AbstractAction, identifier="configure_databa def form_request(self, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" - if config.node_name is None: return ["do_nothing"] return ["network", "node", config.node_name, "application", "DatabaseClient", "configure", config.model_config] diff --git a/src/primaite/game/agent/actions/file.py b/src/primaite/game/agent/actions/file.py index c5ba1602..5d12b27a 100644 --- a/src/primaite/game/agent/actions/file.py +++ b/src/primaite/game/agent/actions/file.py @@ -178,7 +178,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction, identifier="node_file_chec class NodeFileRepairAction(NodeFileAbstractAction, identifier="node_file_repair"): - """Action which repairs a file""" + """Action which repairs a file.""" verb: str = "repair" diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py index 764b5b6e..6c8353b0 100644 --- a/src/primaite/game/agent/actions/manager.py +++ b/src/primaite/game/agent/actions/manager.py @@ -149,7 +149,8 @@ class ActionManager: Since the agent uses a discrete action space which acts as a flattened version of the component-based action space, action_map provides a mapping between an integer (chosen by the agent) and a meaningful action and values of parameters. For example action 0 can correspond to do nothing, action 1 can - correspond to "node_service_scan" with ``node_name="server"`` and ``service_name="WebBrowser"``, action 2 can be " + correspond to "node_service_scan" with ``node_name="server"`` and + ``service_name="WebBrowser"``, action 2 can be " 3. ``options`` ``options`` contains a dictionary of options which are passed to the ActionManager's __init__ method. These options are used to calculate the shape of the action space, and to provide additional information diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index ab4209e5..a69a8a5f 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -89,7 +89,11 @@ class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_acti class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"): + """Action which performs an NMAP ping scan.""" + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): + """Configuration schema for NodeNMAPPingScanAction.""" + pass @classmethod @@ -110,6 +114,8 @@ class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_ """Action which performs an NMAP port scan.""" class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): + """Configuration Schema for NodeNMAPPortScanAction.""" + source_node: str target_protocol: Optional[Union[str, List[str]]] = (None,) target_port: Optional[Union[str, List[str]]] = (None,) @@ -141,6 +147,8 @@ class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_net """Action which performs an NMAP network service recon (ping scan followed by port scan).""" class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration schema for NodeNetworkServiceReconAction.""" + target_protocol: Optional[Union[str, List[str]]] = (None,) target_port: Optional[Union[str, List[str]]] = (None,) show: Optional[bool] = (False,) diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py index 79ff0705..dcae8b47 100644 --- a/src/primaite/game/agent/actions/session.py +++ b/src/primaite/game/agent/actions/session.py @@ -67,14 +67,14 @@ class NodeSessionsRemoteLogoutAction(NodeSessionAbstractAction, identifier="node class ConfigSchema(NodeSessionAbstractAction.ConfigSchema): """Configuration schema for NodeSessionsRemoteLogoutAction.""" - pass + verb: str = "remote_logoff" @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" if config.node_name is None or config.remote_ip is None: return ["do_nothing"] - return ["network", "node", config.node_name, "service", "Terminal", "remote_logoff", config.remote_ip] + return ["network", "node", config.node_name, "service", "Terminal", config.verb, config.remote_ip] class NodeAccountChangePasswordAction(NodeSessionAbstractAction, identifier="node_account_change_password"): diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index baa4c725..f380ba7d 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -166,6 +166,7 @@ def test_router_acl_addrule_integration(game_and_agent: Tuple[PrimaiteGame, Prox game.step() # 5: Check that the ACL now has 6 rules, but that server_1 can still ping server_2 + print(router.acl.show()) assert router.acl.num_rules == 6 assert server_1.ping("10.0.2.3") # Can ping server_2