From a3dc616126ae850c291aad4ac8cef0d4ddf88447 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 20 Nov 2024 17:19:35 +0000 Subject: [PATCH] #2869 - Starter changes in refactor of agent classes for refactor to become extensible. Identifiers added to classes and beginning of the inclusion of a ConfigSchema to base AbstractAgentClass --- src/primaite/game/agent/interface.py | 150 ++++++++++++------ .../scripted_agents/data_manipulation_bot.py | 21 ++- .../scripted_agents/probabilistic_agent.py | 2 +- .../agent/scripted_agents/random_agent.py | 4 +- .../game/agent/scripted_agents/tap001.py | 55 ++++--- src/primaite/game/game.py | 20 +++ tests/conftest.py | 2 +- 7 files changed, 174 insertions(+), 80 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 14b97821..7adaab69 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,7 +1,9 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK """Interface for agents.""" +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, ClassVar, Dict, List, Optional, Tuple, Type, TYPE_CHECKING, Union from gymnasium.core import ActType, ObsType from pydantic import BaseModel, model_validator @@ -69,7 +71,7 @@ class AgentSettings(BaseModel): """Settings for configuring the operation of an agent.""" start_settings: Optional[AgentStartSettings] = None - "Configuration for when an agent begins performing it's actions" + "Configuration for when an agent begins performing it's actions." flatten_obs: bool = True "Whether to flatten the observation space before passing it to the agent. True by default." 
action_masking: bool = False @@ -90,38 +92,78 @@ class AgentSettings(BaseModel): return cls(**config) -class AbstractAgent(ABC): +class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): """Base class for scripted and RL agents.""" - def __init__( - self, - agent_name: Optional[str], - action_space: Optional[ActionManager], - observation_space: Optional[ObservationManager], - reward_function: Optional[RewardFunction], - agent_settings: Optional[AgentSettings] = None, - ) -> None: - """ - Initialize an agent. + _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {} + config: "AbstractAgent.ConfigSchema" + action_manager: Optional[ActionManager] + observation_manager: Optional[ObservationManager] + reward_function: Optional[RewardFunction] + + class ConfigSchema(BaseModel): + """ + Configuration Schema for AbstractAgents. + + :param type: Type of agent being generated. + :type type: str :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes. - :type agent_name: Optional[str] - :param action_space: Action space for the agent. - :type action_space: Optional[ActionManager] + :type agent_name: str :param observation_space: Observation space for the agent. :type observation_space: Optional[ObservationSpace] :param reward_function: Reward function for the agent. :type reward_function: Optional[RewardFunction] - :param agent_settings: Configurable Options for Abstracted Agents + :param agent_settings: Configurable Options for Abstracted Agents. 
:type agent_settings: Optional[AgentSettings] """ - self.agent_name: str = agent_name or "unnamed_agent" - self.action_manager: Optional[ActionManager] = action_space - self.observation_manager: Optional[ObservationManager] = observation_space - self.reward_function: Optional[RewardFunction] = reward_function - self.agent_settings = agent_settings or AgentSettings() - self.history: List[AgentHistoryItem] = [] - self.logger = AgentLog(agent_name) + + type: str + agent_name: ClassVar[str] + agent_settings: Optional[AgentSettings] = None + history: List[AgentHistoryItem] = [] + logger: AgentLog = AgentLog(agent_name) + + # def __init__( + # self, + # agent_name: Optional[str], + # action_space: Optional[ActionManager], + # observation_space: Optional[ObservationManager], + # reward_function: Optional[RewardFunction], + # agent_settings: Optional[AgentSettings] = None, + # ) -> None: + # """ + # Initialize an agent. + + # :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes. + # :type agent_name: Optional[str] + # :param action_space: Action space for the agent. + # :type action_space: Optional[ActionManager] + # :param observation_space: Observation space for the agent. + # :type observation_space: Optional[ObservationSpace] + # :param reward_function: Reward function for the agent.
+ # :type reward_function: Optional[RewardFunction] + # :param agent_settings: Configurable Options for Abstracted Agents + # :type agent_settings: Optional[AgentSettings] + # """ + # self.agent_name: str = agent_name or "unnamed_agent" + # self.action_manager: Optional[ActionManager] = action_space + # self.observation_manager: Optional[ObservationManager] = observation_space + # self.reward_function: Optional[RewardFunction] = reward_function + # self.agent_settings = agent_settings or AgentSettings() + # self.history: List[AgentHistoryItem] = [] + # self.logger = AgentLog(agent_name) + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Cannot create a new agent under reserved name {identifier}") + cls._registry[identifier] = cls + + @classmethod + def from_config(cls, config: Dict) -> "AbstractAgent": + """Creates an agent component from a configuration dictionary.""" + return cls(config=cls.ConfigSchema(**config)) def update_observation(self, state: Dict) -> ObsType: """ @@ -130,7 +172,7 @@ class AbstractAgent(ABC): state : dict state directly from simulation.describe_state output : dict state according to CAOS. """ - return self.observation_manager.update(state) + return self.config.observation_manager.update(state) def update_reward(self, state: Dict) -> float: """ @@ -141,7 +183,7 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.update(state=state, last_action_response=self.history[-1]) + return self.config.reward_function.update(state=state, last_action_response=self.history[-1]) @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -165,14 +207,14 @@ class AbstractAgent(ABC): # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. 
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - request = self.action_manager.form_request(action_identifier=action, action_options=options) + request = self.config.action_manager.form_request(action_identifier=action, action_options=options) return request def process_action_response( self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse ) -> None: """Process the response from the most recent action.""" - self.history.append( + self.config.history.append( AgentHistoryItem( timestep=timestep, action=action, parameters=parameters, request=request, response=response ) @@ -180,10 +222,10 @@ class AbstractAgent(ABC): def save_reward_to_history(self) -> None: """Update the most recent history item with the reward value.""" - self.history[-1].reward = self.reward_function.current_reward + self.config.history[-1].reward = self.config.reward_function.current_reward -class AbstractScriptedAgent(AbstractAgent): +class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"): """Base class for actors which generate their own behaviour.""" @abstractmethod @@ -192,26 +234,34 @@ class AbstractScriptedAgent(AbstractAgent): return super().get_action(obs=obs, timestep=timestep) -class ProxyAgent(AbstractAgent): +class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"): """Agent that sends observations to an RL model and receives actions from that model.""" - def __init__( - self, - agent_name: Optional[str], - action_space: Optional[ActionManager], - observation_space: Optional[ObservationManager], - reward_function: Optional[RewardFunction], - agent_settings: Optional[AgentSettings] = None, - ) -> None: - super().__init__( - agent_name=agent_name, - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - ) - 
self.most_recent_action: ActType - self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - self.action_masking: bool = agent_settings.action_masking if agent_settings else False + class ConfigSchema(AbstractAgent.ConfigSchema): + """Configuration Schema for Proxy Agent.""" + + agent_settings: Optional[AgentSettings] = None + most_recent_action: ActType + flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False + action_masking: bool = agent_settings.action_masking if agent_settings else False + + # def __init__( + # self, + # agent_name: Optional[str], + # action_space: Optional[ActionManager], + # observation_space: Optional[ObservationManager], + # reward_function: Optional[RewardFunction], + # agent_settings: Optional[AgentSettings] = None, + # ) -> None: + # super().__init__( + # agent_name=agent_name, + # action_space=action_space, + # observation_space=observation_space, + # reward_function=reward_function, + # ) + # self.most_recent_action: ActType + # self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False + # self.action_masking: bool = agent_settings.action_masking if agent_settings else False def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ @@ -224,7 +274,7 @@ class ProxyAgent(AbstractAgent): :return: Action to be taken in CAOS format. :rtype: Tuple[str, Dict] """ - return self.action_manager.get_action(self.most_recent_action) + return self.config.action_manager.get_action(self.config.most_recent_action) def store_action(self, action: ActType): """ @@ -232,4 +282,4 @@ Stores the most recent action taken by the agent. The environment is responsible for calling this method when it receives an action from the agent policy.
""" - self.most_recent_action = action + self.config.most_recent_action = action diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 129fac1a..55b2d08b 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -7,12 +7,22 @@ from gymnasium.core import ObsType from primaite.game.agent.interface import AbstractScriptedAgent -class DataManipulationAgent(AbstractScriptedAgent): +class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation_Agent"): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" next_execution_timestep: int = 0 starting_node_idx: int = 0 + config: "DataManipulationAgent.ConfigSchema" + + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration Schema for DataManipulationAgent.""" + + # TODO: Could be worth moving this to a "AbstractTAPAgent" + starting_node_name: str + starting_application_name: str + next_execution_timestep: int + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.setup_agent() @@ -38,12 +48,15 @@ class DataManipulationAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: - self.logger.debug(msg="Performing do NOTHING") - return "DONOTHING", {} + self.logger.debug(msg="Performing do nothing action") + return "do_nothing", {} self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) self.logger.info(msg="Performing a data manipulation attack!") - return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} + return "node_application_execute", { + "node_name": self.config.starting_node_name, + "application_name": self.config.starting_application_name, + } def setup_agent(self) -> None: """Set the next execution timestep when the episode 
resets.""" diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index cd44644f..b8df7838 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -12,7 +12,7 @@ from primaite.game.agent.observations.observation_manager import ObservationMana from primaite.game.agent.rewards import RewardFunction -class ProbabilisticAgent(AbstractScriptedAgent): +class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"): """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" class Settings(pydantic.BaseModel): diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index df9273f7..99b8a1e9 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -11,7 +11,7 @@ from primaite.game.agent.observations.observation_manager import ObservationMana from primaite.game.agent.rewards import RewardFunction -class RandomAgent(AbstractScriptedAgent): +class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"): """Agent that ignores its observation and acts completely at random.""" def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -27,7 +27,7 @@ class RandomAgent(AbstractScriptedAgent): return self.action_manager.get_action(self.action_manager.space.sample()) -class PeriodicAgent(AbstractScriptedAgent): +class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"): """Agent that does nothing most of the time, but executes application at regular intervals (with variance).""" class Settings(BaseModel): diff --git a/src/primaite/game/agent/scripted_agents/tap001.py b/src/primaite/game/agent/scripted_agents/tap001.py index c4f6062a..78cb9293 100644 --- 
a/src/primaite/game/agent/scripted_agents/tap001.py +++ b/src/primaite/game/agent/scripted_agents/tap001.py @@ -1,4 +1,6 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK +from __future__ import annotations + import random from typing import Dict, Tuple @@ -7,20 +9,27 @@ from gymnasium.core import ObsType from primaite.game.agent.interface import AbstractScriptedAgent -class TAP001(AbstractScriptedAgent): +class TAP001(AbstractScriptedAgent, identifier="TAP001"): """ TAP001 | Mobile Malware -- Ransomware Variant. Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application) """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.setup_agent() + # TODO: Link with DataManipulationAgent via a parent "TAP" agent class. - next_execution_timestep: int = 0 - starting_node_idx: int = 0 - installed: bool = False + config: "TAP001.ConfigSchema" + + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration Schema for TAP001 Agent.""" + + starting_node_name: str + next_execution_timestep: int = 0 + installed: bool = False + + # def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs) + # self.setup_agent() def _set_next_execution_timestep(self, timestep: int) -> None: """Set the next execution timestep with a configured random variance. @@ -28,9 +37,9 @@ class TAP001(AbstractScriptedAgent): :param timestep: The timestep to add variance to. 
""" random_timestep_increment = random.randint( - -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + -self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance ) - self.next_execution_timestep = timestep + random_timestep_increment + self.config.next_execution_timestep = timestep + random_timestep_increment def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: """Waits until a specific timestep, then attempts to execute the ransomware application. @@ -45,28 +54,28 @@ class TAP001(AbstractScriptedAgent): :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ - if timestep < self.next_execution_timestep: - return "DONOTHING", {} + if timestep < self.config.next_execution_timestep: + return "do_nothing", {} - self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) + self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency) - if not self.installed: - self.installed = True - return "NODE_APPLICATION_INSTALL", { - "node_id": self.starting_node_idx, + if not self.config.installed: + self.config.installed = True + return "node_application_install", { + "node_name": self.config.starting_node_name, "application_name": "RansomwareScript", } - return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} + return "node_application_execute", {"node_name": self.config.starting_node_name, "application_id": 0} def setup_agent(self) -> None: """Set the next execution timestep when the episode resets.""" self._select_start_node() - self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) - for n, act in self.action_manager.action_map.items(): - if not act[0] == "NODE_APPLICATION_INSTALL": + self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step) + for n, act in self.config.action_manager.action_map.items(): + 
if not act[0] == "node_application_install": continue - if act[1]["node_id"] == self.starting_node_idx: + if act[1]["node_name"] == self.config.starting_node_name: self.ip_address = act[1]["ip_address"] return raise RuntimeError("TAP001 agent could not find database server ip address in action map") @@ -74,5 +83,7 @@ class TAP001(AbstractScriptedAgent): def _select_start_node(self) -> None: """Set the starting starting node of the agent to be a random node from this agent's action manager.""" # we are assuming that every node in the node manager has a data manipulation application at idx 0 - num_nodes = len(self.action_manager.node_names) + num_nodes = len(self.config.action_manager.node_names) + # TODO: Change this to something? self.starting_node_idx = random.randint(0, num_nodes - 1) + self.logger.debug(f"Selected Starting node ID: {self.starting_node_idx}") diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index c8fbac4e..2ef7b1c5 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -547,6 +547,26 @@ class PrimaiteGame: reward_function = RewardFunction.from_config(reward_function_cfg) # CREATE AGENT + + # TODO: MAKE THIS BIT WORK AND NOT THE IF/ELSE CHAIN OF HORRORS + + # Pass through: + # config + # action manager + # observation_manager + # reward_function + + new_agent_cfg = { + "action_manager": action_space, + "agent_name": agent_cfg["ref"], + "observation_manager": obs_space, + "agent_settings": agent_cfg.get("agent_settings", {}), + "reward_function": reward_function, + } + new_agent_cfg = agent_cfg["settings"] + # new_agent_cfg.update{} + new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=new_agent_cfg) + if agent_type == "ProbabilisticAgent": # TODO: implement non-random agents and fix this parsing settings = agent_cfg.get("agent_settings", {}) diff --git a/tests/conftest.py b/tests/conftest.py index 64fe0699..efdb515e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -264,7 +264,7 @@ 
def example_network() -> Network: return network -class ControlledAgent(AbstractAgent): +class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"): """Agent that can be controlled by the tests.""" def __init__(