From a4fbd29bb4de8762423add2ed7581eb002e7aa63 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Mon, 16 Dec 2024 15:57:00 +0000 Subject: [PATCH] #2869 - Updates to agents to make sure they can be generated from a given config. Updates to test suite to reflect code changes --- .../agent/scripted_agents/abstract_tap.py | 6 ++-- .../scripted_agents/data_manipulation_bot.py | 22 +++++++++------ .../game/agent/scripted_agents/interface.py | 20 ++++++------- .../scripted_agents/probabilistic_agent.py | 26 ++++++++--------- .../agent/scripted_agents/random_agent.py | 4 +-- src/primaite/game/game.py | 20 ++++++++----- tests/conftest.py | 20 +++++++------ .../_game/_agent/test_probabilistic_agent.py | 28 +++++++++++++------ 8 files changed, 83 insertions(+), 63 deletions(-) diff --git a/src/primaite/game/agent/scripted_agents/abstract_tap.py b/src/primaite/game/agent/scripted_agents/abstract_tap.py index fb3f1688..95769624 100644 --- a/src/primaite/game/agent/scripted_agents/abstract_tap.py +++ b/src/primaite/game/agent/scripted_agents/abstract_tap.py @@ -3,11 +3,11 @@ from __future__ import annotations import random from abc import abstractmethod -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple from gymnasium.core import ObsType -from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent +from primaite.game.agent.scripted_agents.interface import AbstractAgent, AbstractScriptedAgent class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"): @@ -20,7 +20,7 @@ class AbstractTAPAgent(AbstractScriptedAgent, identifier="Abstract_TAP"): class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration schema for Abstract TAP agents.""" - starting_node_name: str + starting_node_name: Optional[str] = None @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index a8b8d292..0f687367 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -1,12 +1,12 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Dict, Tuple +from typing import Any, Dict, Optional, Tuple from gymnasium.core import ObsType from primaite.game.agent.scripted_agents.abstract_tap import AbstractTAPAgent -class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agent"): +class DataManipulationAgent(AbstractTAPAgent, identifier="RedDatabaseCorruptingAgent"): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" config: "DataManipulationAgent.ConfigSchema" @@ -14,12 +14,12 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen class ConfigSchema(AbstractTAPAgent.ConfigSchema): """Configuration Schema for DataManipulationAgent.""" + starting_application_name: Optional[str] = None - starting_application_name: str - - def __init__(self) -> None: - """Initialise DataManipulationAgent.""" - self.setup_agent() + # def __init__(self, **kwargs: Any) -> None: + # """Initialise DataManipulationAgent.""" + # # self.setup_agent() + # super().__init_subclass__(**kwargs) @property def next_execution_timestep(self) -> int: @@ -41,11 +41,15 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen :return: Action formatted in CAOS format :rtype: Tuple[str, Dict] """ + if self.starting_node_name or self.config is None: + self.setup_agent() + self.get_action(obs=obs, timestep=timestep) + if timestep < self.next_execution_timestep: self.logger.debug(msg="Performing do nothing action") return "do_nothing", {} - self._set_next_execution_timestep(timestep + self.config._agent_settings.start_settings.frequency) + self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency) self.logger.info(msg="Performing a data manipulation attack!") return "node_application_execute", { "node_name": self.config.starting_node_name, @@ -55,4 +59,4 @@ class DataManipulationAgent(AbstractTAPAgent, identifier="Data_Manipulation_Agen def setup_agent(self) -> None: """Set the next execution timestep when the episode resets.""" self._select_start_node() - self._set_next_execution_timestep(self.config._agent_settings.start_settings.start_step) + self._set_next_execution_timestep(self.config.agent_settings.start_settings.start_step) diff --git a/src/primaite/game/agent/scripted_agents/interface.py b/src/primaite/game/agent/scripted_agents/interface.py index bc083ecf..5e9167f5 100644 --- a/src/primaite/game/agent/scripted_agents/interface.py +++ b/src/primaite/game/agent/scripted_agents/interface.py @@ -115,14 +115,14 @@ class AbstractAgent(BaseModel): :type agent_settings: Optional[AgentSettings] """ - agent_name: str = "Abstract_Agent" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="forbid", arbitrary_types_allowed=True) + agent_name: Optional[str] = "Abstract_Agent" history: List[AgentHistoryItem] = [] _logger: AgentLog = AgentLog(agent_name=agent_name) - _action_manager: Optional[ActionManager] = None - _observation_manager: Optional[ObservationManager] = None - _reward_function: Optional[RewardFunction] = None - _agent_settings: Optional[AgentSettings] = None + action_manager: ActionManager + observation_manager: ObservationManager + reward_function: RewardFunction + agent_settings: Optional[AgentSettings] = None def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: if identifier in cls._registry: @@ -138,17 +138,17 @@ class AbstractAgent(BaseModel): @property def observation_manager(self) -> ObservationManager: """Returns the agents observation manager.""" - return self.config._observation_manager + return self.config.observation_manager @property def action_manager(self) -> ActionManager: """Returns the agents action manager.""" - return self.config._action_manager + return self.config.action_manager @property def reward_function(self) -> RewardFunction: """Returns the agents reward function.""" - return self.config._reward_function + return self.config.reward_function @classmethod def from_config(cls, config: Dict) -> "AbstractAgent": @@ -232,7 +232,7 @@ class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent") return super().get_action(obs=obs, timestep=timestep) -class ProxyAgent(AbstractAgent, identifier="Proxy_Agent"): +class ProxyAgent(AbstractAgent, identifier="ProxyAgent"): """Agent that sends observations to an RL model and receives actions from that model.""" config: "ProxyAgent.ConfigSchema" diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index c2d7d580..750a120f 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -1,28 +1,23 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK """Agents with predefined behaviours.""" -from typing import Dict, Tuple +from typing import Any, Dict, Tuple import numpy as np import pydantic from gymnasium.core import ObsType -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent +from primaite.game.agent.scripted_agents.interface import AbstractScriptedAgent, AgentSettings class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent"): """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" config: "ProbabilisticAgent.ConfigSchema" - agent_name: str = "ProbabilisticAgent" + rng: Any = np.random.default_rng(np.random.randint(0, 65535)) - class ConfigSchema(AbstractScriptedAgent.ConfigSchema): - """Configuration schema for Probabilistic Agent.""" - - action_space: ActionManager + class AgentSettings(AgentSettings): action_probabilities: Dict[int, float] """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" - @pydantic.field_validator("action_probabilities", mode="after") @classmethod def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]: @@ -42,16 +37,17 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="ProbabilisticAgent") ) return v - # def __init__(self, **kwargs) -> None: - # rng_seed = np.random.randint(0, 65535) - # self.rng = np.random.default_rng(rng_seed) - # self.logger.debug(f"ProbabilisticAgent RNG seed: {rng_seed}") - # super().__init_subclass__(**kwargs) + class ConfigSchema(AbstractScriptedAgent.ConfigSchema): + """Configuration schema for Probabilistic Agent.""" + + agent_name: str = "ProbabilisticAgent" + agent_settings: "ProbabilisticAgent.AgentSettings" + @property def probabilities(self) -> Dict[str, int]: """Convenience method to view the probabilities of the Agent.""" - return np.asarray(list(self.config.action_probabilities.values())) + return np.asarray(list(self.config.agent_settings.action_probabilities.values())) def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py index b0c0f7ce..d28069e6 100644 --- a/src/primaite/game/agent/scripted_agents/random_agent.py +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -13,7 +13,7 @@ class RandomAgent(AbstractScriptedAgent, identifier="Random_Agent"): class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration Schema for Random Agents.""" - agent_name = "Random_Agent" + agent_name: str = "Random_Agent" def get_action(self) -> Tuple[str, Dict]: """Sample the action space randomly. @@ -36,7 +36,7 @@ class PeriodicAgent(AbstractScriptedAgent, identifier="Periodic_Agent"): class ConfigSchema(AbstractScriptedAgent.ConfigSchema): """Configuration Schema for Periodic Agent.""" - agent_name = "Periodic_Agent" + agent_name: str = "Periodic_Agent" """Name of the agent.""" start_step: int = 20 "The timestep at which an agent begins performing it's actions." diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index f307bba5..6cf4a75a 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -174,7 +174,7 @@ class PrimaiteGame: obs = agent.observation_manager.current_observation action_choice, parameters = agent.get_action(obs, timestep=self.step_counter) if SIM_OUTPUT.save_agent_logs: - agent.config.logger.debug(f"Chosen Action: {action_choice}") + agent.logger.debug(f"Chosen Action: {action_choice}") request = agent.format_request(action_choice, parameters) response = self.simulation.apply_request(request) agent.process_action_response( @@ -544,14 +544,20 @@ class PrimaiteGame: # CREATE AGENT - agent_config = agent_cfg.get("agent_settings", {}) - agent_config.update( - {"action_manager": action_space, "observation_manager": obs_space, "reward_function": reward_function} - ) - # new_agent_cfg.update{} - print(AbstractAgent._registry) + agent_settings = agent_cfg["agent_settings"] + agent_config = { + "agent_name": agent_ref, + "action_manager": action_space, + "observation_manager": obs_space, + "reward_function": reward_function, + "agent_settings": agent_settings, + } + # new_agent_cfg.update{} if agent_type in AbstractAgent._registry: + print(agent_type) + print(agent_config) + print(AbstractAgent._registry) new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config) # If blue agent is created, add to game.rl_agents if agent_type == "ProxyAgent": diff --git a/tests/conftest.py b/tests/conftest.py index b693a5e6..68097830 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ # © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK -from typing import Any, Dict, Tuple +from typing import Any, Dict, Optional, Tuple import pytest import yaml @@ -10,6 +10,7 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.agent.scripted_agents.interface import AbstractAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent from primaite.game.game import PrimaiteGame from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.container import Network @@ -268,12 +269,12 @@ class ControlledAgent(AbstractAgent, identifier="Controlled_Agent"): """Agent that can be controlled by the tests.""" config: "ControlledAgent.ConfigSchema" + most_recent_action: Optional[Tuple[str, Dict]] = None class ConfigSchema(AbstractAgent.ConfigSchema): """Configuration Schema for Abstract Agent used in tests.""" agent_name: str = "Controlled_Agent" - most_recent_action: Tuple[str, Dict] def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" @@ -496,12 +497,15 @@ def game_and_agent(): observation_space = ObservationManager(NestedObservation(components={})) reward_function = RewardFunction() - test_agent = ControlledAgent( - agent_name="test_agent", - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - ) + + config = { + "agent_name":"test_agent", + "action_manager":action_space, + "observation_manager":observation_space, + "reward_function":reward_function, + } + + test_agent = ControlledAgent.from_config(config=config) game.agents["test_agent"] = test_agent diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index ec18f1fb..6e8c9c79 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -55,15 +55,25 @@ def test_probabilistic_agent(): observation_space = ObservationManager(NestedObservation(components={})) reward_function = RewardFunction() - pa = ProbabilisticAgent( - agent_name="test_agent", - action_space=action_space, - observation_space=observation_space, - reward_function=reward_function, - settings={ - "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, - }, - ) + # pa = ProbabilisticAgent( + # agent_name="test_agent", + # action_space=action_space, + # observation_space=observation_space, + # reward_function=reward_function, + # settings={ + # "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + # }, + # ) + + pa_config = {"agent_name":"test_agent", + "action_manager": action_space, + "observation_manager": observation_space, + "reward_function": reward_function, + "agent_settings": { + "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + }} + + pa = ProbabilisticAgent.from_config(config=pa_config) do_nothing_count = 0 node_application_execute_count = 0