From 518b934e09086d7dcc1c9ebc7e0b763dbbf85b5e Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Tue, 22 Oct 2024 17:02:54 +0100 Subject: [PATCH] #2912 - Corrections to some actions & fixing some linting. TODO: Action Manager errors --- src/primaite/game/agent/actions/acl.py | 38 ++++---- src/primaite/game/agent/actions/config.py | 2 +- src/primaite/game/agent/actions/manager.py | 14 +-- src/primaite/game/agent/actions/node.py | 97 ++++++++++++++++++- src/primaite/game/agent/actions/service.py | 7 +- src/primaite/game/agent/actions/session.py | 7 +- src/primaite/notebooks/Action-masking.ipynb | 7 +- .../Command-&-Control-E2E-Demonstration.ipynb | 4 +- .../notebooks/Training-an-SB3-Agent.ipynb | 4 +- .../create-simulation_demo.ipynb | 6 +- .../network_simulator_demo.ipynb | 4 +- 11 files changed, 149 insertions(+), 41 deletions(-) diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index 1048dc1e..cc89bfba 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -3,10 +3,16 @@ from typing import Dict, List, Literal from pydantic import BaseModel, Field, field_validator, ValidationInfo -from primaite.game.agent.actions.manager import AbstractAction +from primaite.game.agent.actions.manager import AbstractAction, ActionManager from primaite.interface.request import RequestFormat -__all__ = ("RouterACLAddRuleAction", "RouterACLRemoveRuleAction", "FirewallACLAddRuleAction", "FirewallACLRemoveRuleAction") +__all__ = ( + "RouterACLAddRuleAction", + "RouterACLRemoveRuleAction", + "FirewallACLAddRuleAction", + "FirewallACLRemoveRuleAction", +) + class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): """Action which adds a rule to a router's ACL.""" @@ -23,8 +29,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): dest_ip_id: int dest_wildcard_id: int dest_port_id: int - protocol_id: int - + protocol_name: str class ACLRuleOptions(BaseModel): """Validator for ACL_ADD_RULE options.""" @@ -68,10 +73,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): return v @classmethod - def form_request( - cls, - config: ConfigSchema - ) -> List[str]: + def form_request(cls, config: ConfigSchema) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" # Validate incoming data. parsed_options = RouterACLAddRuleAction.ACLRuleOptions( @@ -84,7 +86,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): dest_wildcard_id=config.dest_wildcard_id, source_port_id=config.source_port_id, dest_port_id=config.dest_port_id, - protocol_id=config.protocol_id, + protocol=config.protocol_name, ) if parsed_options.permission == 1: permission_str = "PERMIT" @@ -96,40 +98,40 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): if parsed_options.protocol_id == 1: protocol = "ALL" else: - protocol = self.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2) + protocol = cls.manager.get_internet_protocol_by_idx(parsed_options.protocol_id - 2) # subtract 2 to account for UNUSED=0 and ALL=1. if parsed_options.source_ip_id == 1: src_ip = "ALL" else: - src_ip = self.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2) + src_ip = cls.manager.get_ip_address_by_idx(parsed_options.source_ip_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 - src_wildcard = self.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id) + src_wildcard = cls.manager.get_wildcard_by_idx(parsed_options.source_wildcard_id) if parsed_options.source_port_id == 1: src_port = "ALL" else: - src_port = self.manager.get_port_by_idx(parsed_options.source_port_id - 2) + src_port = cls.manager.get_port_by_idx(parsed_options.source_port_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 if parsed_options.dest_ip_id == 1: dst_ip = "ALL" else: - dst_ip = self.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2) + dst_ip = cls.manager.get_ip_address_by_idx(parsed_options.dest_ip_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 - dst_wildcard = self.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id) + dst_wildcard = cls.manager.get_wildcard_by_idx(parsed_options.dest_wildcard_id) if parsed_options.dest_port_id == 1: dst_port = "ALL" else: - dst_port = self.manager.get_port_by_idx(parsed_options.dest_port_id - 2) + dst_port = cls.manager.get_port_by_idx(parsed_options.dest_port_id - 2) # subtract 2 to account for UNUSED=0, and ALL=1 return [ "network", "node", - target_router, + config.target_router, "acl", "add_rule", permission_str, @@ -140,7 +142,7 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): str(dst_ip), dst_wildcard, dst_port, - position, + config.position, ] diff --git a/src/primaite/game/agent/actions/config.py b/src/primaite/game/agent/actions/config.py index e92d443b..582e8ec7 100644 --- a/src/primaite/game/agent/actions/config.py +++ b/src/primaite/game/agent/actions/config.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo -from primaite.game.agent.actions.manager import AbstractAction +from primaite.game.agent.actions.manager import AbstractAction, ActionManager from primaite.interface.request import RequestFormat __all__ = ( diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py index 2f47ea7c..7677b39a 100644 --- a/src/primaite/game/agent/actions/manager.py +++ b/src/primaite/game/agent/actions/manager.py @@ -87,6 +87,8 @@ class ActionManager: # ip_list: List[str] = [], # to allow us to map an index to an ip address. # wildcard_list: List[str] = [], # to allow mapping from wildcard index to act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions + *args, + **kwargs, ) -> None: """Init method for ActionManager. @@ -116,27 +118,27 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.node_names: List[str] = [n["node_name"] for n in nodes] + # self.node_names: List[str] = [n["node_name"] for n in nodes] """List of node names in this action space. The list order is the mapping between node index and node name.""" - self.application_names: List[List[str]] = [] + # self.application_names: List[List[str]] = [] """ List of applications per node. The list order gives the two-index mapping between (node_id, app_id) to app name. The first index corresponds to node id, the second index is the app id on that particular node. For instance, self.application_names[0][2] is the name of the third application on the first node. """ - self.service_names: List[List[str]] = [] + # self.service_names: List[List[str]] = [] """ List of services per node. The list order gives the two-index mapping between (node_id, svc_id) to svc name. The first index corresponds to node id, the second index is the service id on that particular node. For instance, self.service_names[0][2] is the name of the third service on the first node. """ - self.folder_names: List[List[str]] = [] + # self.folder_names: List[List[str]] = [] """ List of folders per node. The list order gives the two-index mapping between (node_id, folder_id) to folder name. The first index corresponds to node id, the second index is the folder id on that particular node. For instance, self.folder_names[0][2] is the name of the third folder on the first node. """ - self.file_names: List[List[List[str]]] = [] + # self.file_names: List[List[List[str]]] = [] """ List of files per folder per node. The list order gives the three-index mapping between (node_id, folder_id, file_id) to file name. The first index corresponds to node id, the second index is the @@ -203,7 +205,7 @@ class ActionManager: # and `options` is an optional dict of options to pass to the init method of the action class act_type = act_spec.get("type") act_options = act_spec.get("options", {}) - self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options) + # self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options) self.action_map: Dict[int, Tuple[str, Dict]] = {} """ diff --git a/src/primaite/game/agent/actions/node.py b/src/primaite/game/agent/actions/node.py index 011ff4dc..f95ba6df 100644 --- a/src/primaite/game/agent/actions/node.py +++ b/src/primaite/game/agent/actions/node.py @@ -1,5 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import ClassVar +from abc import abstractmethod +from typing import ClassVar, List, Optional, Union from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat @@ -61,3 +62,97 @@ class NodeResetAction(NodeAbstractAction, identifier="node_reset"): """Configuration schema for NodeResetAction.""" verb: str = "reset" + + +class NodeNMAPAbstractAction(AbstractAction, identifier="node_nmap_abstract_action"): + """Base class for NodeNMAP actions.""" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Base Configuration Schema for NodeNMAP actions.""" + + target_ip_address: Union[str, List[str]] + show: bool = False + node_name: str + + @classmethod + @abstractmethod + def form_request(cls, config: ConfigSchema) -> RequestFormat: + pass + + +class NodeNMAPPingScanAction(NodeNMAPAbstractAction, identifier="node_nmap_ping_scan"): + class ConfigSchema(NodeNMAPAbstractAction.ConfigSchema): + pass + + @classmethod + def form_request(cls, config: ConfigSchema) -> List[str]: # noqa + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.node_name, + "application", + "NMAP", + "ping_scan", + {"target_ip_address": config.target_ip_address, "show": config.show}, + ] + + +class NodeNMAPPortScanAction(NodeNMAPAbstractAction, identifier="node_nmap_port_scan"): + """Action which performs an NMAP port scan.""" + + class ConfigSchema(AbstractAction.ConfigSchema): + target_protocol: Optional[Union[str, List[str]]] = (None,) + target_port: Optional[Union[str, List[str]]] = (None,) + show: Optional[bool] = (False,) + + @classmethod + def form_request( + cls, + config: ConfigSchema, + ) -> List[str]: # noqa + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.source_node, + "application", + "NMAP", + "port_scan", + { + "target_ip_address": config.target_ip_address, + "target_port": config.target_port, + "target_protocol": config.target_protocol, + "show": config.show, + }, + ] + + +class NodeNetworkServiceReconAction(NodeNMAPAbstractAction, identifier="node_network_service_recon"): + """Action which performs an NMAP network service recon (ping scan followed by port scan).""" + + class ConfigSchema(AbstractAction.ConfigSchema): + target_protocol: Optional[Union[str, List[str]]] = (None,) + target_port: Optional[Union[str, List[str]]] = (None,) + show: Optional[bool] = (False,) + + @classmethod + def form_request( + cls, + config: ConfigSchema, + ) -> List[str]: # noqa + """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" + return [ + "network", + "node", + config.source_node, + "application", + "NMAP", + "network_service_recon", + { + "target_ip_address": config.target_ip_address, + "target_port": config.target_port, + "target_protocol": config.target_protocol, + "show": config.show, + }, + ] diff --git a/src/primaite/game/agent/actions/service.py b/src/primaite/game/agent/actions/service.py index cf277b5d..bccfaba2 100644 --- a/src/primaite/game/agent/actions/service.py +++ b/src/primaite/game/agent/actions/service.py @@ -18,6 +18,11 @@ __all__ = ( class NodeServiceAbstractAction(AbstractAction, identifier="node_service_abstract"): + """Abstract Action for Node Service related actions. + + Any actions which use node_name and service_name can inherit from this class. + """ + class ConfigSchema(AbstractAction.ConfigSchema): node_name: str service_name: str @@ -34,7 +39,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction, identifier="node_service_ """Action which scans a service.""" class ConfigSchema(NodeServiceAbstractAction.ConfigSchema): - """Configuration Schema for NodeServiceScanAction""" + """Configuration Schema for NodeServiceScanAction.""" verb: str = "scan" diff --git a/src/primaite/game/agent/actions/session.py b/src/primaite/game/agent/actions/session.py index eb035ff3..f77a85b1 100644 --- a/src/primaite/game/agent/actions/session.py +++ b/src/primaite/game/agent/actions/session.py @@ -19,8 +19,11 @@ class NodeSessionAbstractAction(AbstractAction, identifier="node_session_abstrac @classmethod @abstractmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: - """Abstract method. Should return the action formatted as a request which - can be ingested by the PrimAITE simulation.""" + """ + Abstract method for request forming. + + Should return the action formatted as a request which can be ingested by the PrimAITE simulation. + """ pass diff --git a/src/primaite/notebooks/Action-masking.ipynb b/src/primaite/notebooks/Action-masking.ipynb index ba70f2b4..d22e171d 100644 --- a/src/primaite/notebooks/Action-masking.ipynb +++ b/src/primaite/notebooks/Action-masking.ipynb @@ -19,7 +19,8 @@ "source": [ "from primaite.session.environment import PrimaiteGymEnv\n", "from primaite.config.load import data_manipulation_config_path\n", - "from prettytable import PrettyTable\n" + "from prettytable import PrettyTable\n", + "UDP=\"UDP\"" ] }, { @@ -195,7 +196,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -209,7 +210,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb index 6e6819fa..a697ca3e 100644 --- a/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Command-&-Control-E2E-Demonstration.ipynb @@ -1783,7 +1783,7 @@ "from primaite.simulator.network.transmission.network_layer import IPProtocol\n", "from primaite.simulator.network.transmission.transport_layer import Port\n", "# As we're configuring via the PrimAITE API we need to pass the actual IPProtocol/Port (Agents leverage the simulation via the game layer and thus can pass strings).\n", - "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol["UDP"], masquerade_port=Port["DNS"])\n", + "c2_beacon.configure(c2_server_ip_address=\"192.168.10.21\", masquerade_protocol=IPProtocol[\"UDP\"], masquerade_port=Port[\"DNS\"])\n", "c2_beacon.establish()\n", "c2_beacon.show()" ] @@ -1804,7 +1804,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, diff --git a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb index 892736fe..5255b0ad 100644 --- a/src/primaite/notebooks/Training-an-SB3-Agent.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -168,7 +168,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -182,7 +182,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb index f573f251..30417b84 100644 --- a/src/primaite/simulator/_package_data/create-simulation_demo.ipynb +++ b/src/primaite/simulator/_package_data/create-simulation_demo.ipynb @@ -182,7 +182,7 @@ "metadata": {}, "outputs": [], "source": [ - "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port["HTTP"], protocol = IPProtocol["NONE"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" + "mspaint = MSPaint(name = \"mspaint\", health_state_actual=SoftwareHealthState.GOOD, health_state_visible=SoftwareHealthState.GOOD, criticality=SoftwareCriticality.MEDIUM, port=Port[\"HTTP\"], protocol = IPProtocol[\"NONE\"],operating_state=ApplicationOperatingState.RUNNING,execution_control_status='manual', file_system=FileSystem(sys_log=SysLog(hostname=\"Test\"), sim_root=Path(__name__).parent),)" ] }, { @@ -249,7 +249,7 @@ ], "metadata": { "kernelspec": { - "display_name": "venv", + "display_name": ".venv", "language": "python", "name": "python3" }, @@ -263,7 +263,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.11" } }, "nbformat": 4, diff --git a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb index 2d5b4772..8406dbdf 100644 --- a/src/primaite/simulator/_package_data/network_simulator_demo.ipynb +++ b/src/primaite/simulator/_package_data/network_simulator_demo.ipynb @@ -537,7 +537,7 @@ "from primaite.simulator.network.hardware.nodes.network.router import ACLAction\n", "network.get_node_by_hostname(\"router_1\").acl.add_rule(\n", " action=ACLAction.DENY,\n", - " protocol=IPProtocol["ICMP"],\n", + " protocol=IPProtocol[\"ICMP\"],\n", " src_ip_address=\"192.168.10.22\",\n", " position=1\n", ")" @@ -650,7 +650,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": ".venv", "language": "python", "name": "python3" },