diff --git a/src/primaite/game/agent/actions/__init__.py b/src/primaite/game/agent/actions/__init__.py index 7f054591..016a09ba 100644 --- a/src/primaite/game/agent/actions/__init__.py +++ b/src/primaite/game/agent/actions/__init__.py @@ -29,4 +29,5 @@ __all__ = ( "node", "service", "session", + "ActionManager", ) diff --git a/src/primaite/game/agent/actions/abstract.py b/src/primaite/game/agent/actions/abstract.py index cd14ef6d..ef22ec54 100644 --- a/src/primaite/game/agent/actions/abstract.py +++ b/src/primaite/game/agent/actions/abstract.py @@ -36,6 +36,4 @@ class AbstractAction(BaseModel): @classmethod def from_config(cls, config: Dict) -> "AbstractAction": """Create an action component from a config dictionary.""" - if not config.get("type"): - config.update({"type": cls.__name__}) return cls(config=cls.ConfigSchema(**config)) diff --git a/src/primaite/game/agent/actions/acl.py b/src/primaite/game/agent/actions/acl.py index e8ad59f5..fb18d025 100644 --- a/src/primaite/game/agent/actions/acl.py +++ b/src/primaite/game/agent/actions/acl.py @@ -1,5 +1,7 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import List, Literal, Union +from __future__ import annotations + +from typing import List from primaite.game.agent.actions.manager import AbstractAction from primaite.interface.request import RequestFormat @@ -12,34 +14,48 @@ __all__ = ( ) -class ACLAbstractAction(AbstractAction, identifier="acl_abstract_action"): - """Base class for ACL actions.""" +class ACLAddRuleAbstractAction(AbstractAction, identifier="acl_add_rule_abstract_action"): + """Base abstract class for ACL add rule actions.""" + + config: ConfigSchema = "ACLAddRuleAbstractAction.ConfigSchema" class ConfigSchema(AbstractAction.ConfigSchema): - """Configuration Schema base for ACL abstract actions.""" + """Configuration Schema base for ACL add rule abstract actions.""" src_ip: str protocol_name: str + permission: str + position: int + src_ip: str + dst_ip: str + src_port: str + dst_port: str + src_wildcard: int + dst_wildcard: int -class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): +class ACLRemoveRuleAbstractAction(AbstractAction, identifier="acl_remove_rule_abstract_action"): + """Base abstract class for ACL remove rule actions.""" + + config: ConfigSchema = "ACLRemoveRuleAbstractAction.ConfigSchema" + + class ConfigSchema(AbstractAction.ConfigSchema): + """Configuration Schema base for ACL remove rule abstract actions.""" + + src_ip: str + protocol_name: str + position: int + + +class RouterACLAddRuleAction(ACLAddRuleAbstractAction, identifier="router_acl_add_rule"): """Action which adds a rule to a router's ACL.""" config: "RouterACLAddRuleAction.ConfigSchema" - class ConfigSchema(AbstractAction.ConfigSchema): + class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema): """Configuration Schema for RouterACLAddRuleAction.""" target_router: str - permission: str - protocol_name: str - position: int - src_ip: str - src_wildcard: int - source_port: str - dst_ip: str - dst_wildcard: int - dst_port: str @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: @@ -62,16 +78,15 @@ class RouterACLAddRuleAction(AbstractAction, identifier="router_acl_add_rule"): ] -class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_rule"): +class RouterACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="router_acl_remove_rule"): """Action which removes a rule from a router's ACL.""" config: "RouterACLRemoveRuleAction.ConfigSchema" - class ConfigSchema(AbstractAction.ConfigSchema): + class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema): """Configuration schema for RouterACLRemoveRuleAction.""" target_router: str - position: int @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: @@ -79,31 +94,22 @@ class RouterACLRemoveRuleAction(AbstractAction, identifier="router_acl_remove_ru return ["network", "node", config.target_router, "acl", "remove_rule", config.position] -class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_rule"): +class FirewallACLAddRuleAction(ACLAddRuleAbstractAction, identifier="firewall_acl_add_rule"): """Action which adds a rule to a firewall port's ACL.""" config: "FirewallACLAddRuleAction.ConfigSchema" - class ConfigSchema(ACLAbstractAction.ConfigSchema): + class ConfigSchema(ACLAddRuleAbstractAction.ConfigSchema): """Configuration schema for FirewallACLAddRuleAction.""" target_firewall_nodename: str firewall_port_name: str firewall_port_direction: str - position: int - permission: str - src_ip: str - dest_ip: str - src_port: str - dst_port: str - protocol_name: str - source_wildcard_id: int - dest_wildcard_id: int @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if config.protocol_name == None: + if config.protocol_name is None: return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS if config.src_ip == 0: return ["do_nothing"] # invalid formulation @@ -121,27 +127,26 @@ class FirewallACLAddRuleAction(ACLAbstractAction, identifier="firewall_acl_add_r config.permission, config.protocol_name, config.src_ip, - config.source_wildcard_id, + config.src_wildcard, config.src_port, - config.dest_ip, - config.dest_wildcard_id, + config.dst_ip, + config.dst_wildcard, config.dst_port, config.position, ] -class FirewallACLRemoveRuleAction(AbstractAction, identifier="firewall_acl_remove_rule"): +class FirewallACLRemoveRuleAction(ACLRemoveRuleAbstractAction, identifier="firewall_acl_remove_rule"): """Action which removes a rule from a firewall port's ACL.""" config: "FirewallACLRemoveRuleAction.ConfigSchema" - class ConfigSchema(AbstractAction.ConfigSchema): + class ConfigSchema(ACLRemoveRuleAbstractAction.ConfigSchema): """Configuration schema for FirewallACLRemoveRuleAction.""" target_firewall_nodename: str firewall_port_name: str firewall_port_direction: str - position: int @classmethod def form_request(cls, config: ConfigSchema) -> List[str]: diff --git a/src/primaite/game/agent/actions/application.py b/src/primaite/game/agent/actions/application.py index f515a8ec..91e34eae 100644 --- a/src/primaite/game/agent/actions/application.py +++ b/src/primaite/game/agent/actions/application.py @@ -34,8 +34,6 @@ class NodeApplicationAbstractAction(AbstractAction, identifier="node_application @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if config.node_name is None or config.application_name is None: - return ["do_nothing"] return [ "network", "node", @@ -103,8 +101,6 @@ class NodeApplicationInstallAction(NodeApplicationAbstractAction, identifier="no @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if config.node_name is None: - return ["do_nothing"] return [ "network", "node", @@ -129,8 +125,6 @@ class NodeApplicationRemoveAction(NodeApplicationAbstractAction, identifier="nod @classmethod def form_request(cls, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request which can be ingested by the PrimAITE simulation.""" - if config.node_name is None: - return ["do_nothing"] return [ "network", "node", diff --git a/src/primaite/game/agent/actions/config.py b/src/primaite/game/agent/actions/config.py index 7c72e57d..319cd212 100644 --- a/src/primaite/game/agent/actions/config.py +++ b/src/primaite/game/agent/actions/config.py @@ -1,8 +1,8 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Dict, List, Optional, Union +from typing import List, Optional, Union -from pydantic import BaseModel, ConfigDict, Field, field_validator, ValidationInfo +from pydantic import ConfigDict, Field, field_validator, ValidationInfo from primaite.game.agent.actions.manager import AbstractAction, ActionManager from primaite.interface.request import RequestFormat @@ -27,7 +27,6 @@ class ConfigureRansomwareScriptAction(AbstractAction, identifier="c2_server_rans class ConfigSchema(AbstractAction.ConfigSchema): """Configuration schema for ConfigureRansomwareScriptAction.""" - model_config = ConfigDict(extra="forbid") node_name: str server_ip_address: Optional[str] server_password: Optional[str] @@ -109,17 +108,7 @@ class ConfigureC2BeaconAction(AbstractAction, identifier="configure_c2_beacon"): @classmethod def form_request(self, config: ConfigSchema) -> RequestFormat: """Return the action formatted as a request that can be ingested by the simulation.""" - if config.node_name is None: - return ["do_nothing"] - configuration = ConfigureC2BeaconAction.ConfigSchema( - c2_server_ip_address=config.c2_server_ip_address, - keep_alive_frequency=config.keep_alive_frequency, - masquerade_port=config.masquerade_port, - masquerade_protocol=config.masquerade_protocol, - ) - - ConfigureC2BeaconAction.ConfigSchema.model_validate(configuration) # check that options adhere to schema - + configuration = [] return ["network", "node", config.node_name, "application", "C2Beacon", "configure", configuration] diff --git a/src/primaite/game/agent/actions/manager.py b/src/primaite/game/agent/actions/manager.py index a413f6dc..b89704f4 100644 --- a/src/primaite/game/agent/actions/manager.py +++ b/src/primaite/game/agent/actions/manager.py @@ -13,8 +13,7 @@ agents: from __future__ import annotations -import itertools -from typing import Dict, List, Literal, Optional, Tuple +from typing import Dict, List, Optional, Tuple from gymnasium import spaces @@ -45,7 +44,9 @@ class ActionManager: def __init__( self, actions: List[Dict], # stores list of actions available to agent - act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions + act_map: Optional[ + Dict[int, Dict] + ] = None, # allows restricting set of possible actions - TODO: Refactor to be a list? *args, **kwargs, ) -> None: @@ -79,43 +80,6 @@ class ActionManager: # make sure all numbers between 0 and N are represented as dict keys in action map assert all([i in self.action_map.keys() for i in range(len(self.action_map))]) - def _enumerate_actions( - self, - ) -> Dict[int, Tuple[str, Dict]]: - """Generate a list of all the possible actions that could be taken. - - This enumerates all actions all combinations of parameters you could choose for those actions. The output - of this function is intended to populate the self.action_map parameter in the situation where the user provides - a list of action types, but doesn't specify any subset of actions that should be made available to the agent. - - The enumeration relies on the Actions' `shape` attribute. - - :return: An action map maps consecutive integers to a combination of Action type and parameter choices. - An example output could be: - {0: ("do_nothing", {'dummy': 0}), - 1: ("node_os_scan", {'node_name': computer}), - 2: ("node_os_scan", {'node_name': server}), - 3: ("node_folder_scan", {'node_name:computer, folder_name:downloads}), - ... #etc... - } - :rtype: Dict[int, Tuple[AbstractAction, Dict]] - """ - all_action_possibilities = [] - for act_name, action in self.actions.items(): - param_names = list(action.shape.keys()) - num_possibilities = list(action.shape.values()) - possibilities = [range(n) for n in num_possibilities] - - param_combinations = list(itertools.product(*possibilities)) - all_action_possibilities.extend( - [ - (act_name, {param_names[i]: param_combinations[j][i] for i in range(len(param_names))}) - for j in range(len(param_combinations)) - ] - ) - - return {i: p for i, p in enumerate(all_action_possibilities)} - def get_action(self, action: int) -> Tuple[str, Dict]: """Produce action in CAOS format.""" """the agent chooses an action (as an integer), this is converted into an action in CAOS format""" @@ -125,8 +89,9 @@ class ActionManager: def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat: """Take action in CAOS format and use the execution definition to change it into PrimAITE request format.""" - act_obj = self.actions[action_identifier].from_config(config=action_options) - return act_obj.form_request(config=act_obj.config) + act_class = AbstractAction._registry[action_identifier] + config = act_class.ConfigSchema(**action_options) + return act_class.form_request(config=config) @property def space(self) -> spaces.Space: @@ -134,7 +99,7 @@ class ActionManager: return spaces.Discrete(len(self.action_map)) @classmethod - def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": + def from_config(cls, game: "PrimaiteGame", cfg: Dict) -> "ActionManager": # noqa: F821 """ Construct an ActionManager from a config definition.