From 7435a4dee8ff59bcafddeffe21c55cc245a63d95 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Thu, 21 Nov 2024 14:45:35 +0000 Subject: [PATCH] #2869 - Commit before changing branches. Addition of properties to Agent classes and removal of if/else chain in game.py --- src/primaite/game/agent/interface.py | 63 +++++---------- .../scripted_agents/data_manipulation_bot.py | 23 ++++-- src/primaite/game/game.py | 80 +++---------------- 3 files changed, 45 insertions(+), 121 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 88557956..962e13f7 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -96,11 +96,6 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): """Base class for scripted and RL agents.""" _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {} - - action_manager: Optional[ActionManager] - observation_manager: Optional[ObservationManager] - reward_function: Optional[RewardFunction] - config: "AbstractAgent.ConfigSchema" class ConfigSchema(BaseModel): @@ -121,39 +116,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): type: str agent_name: ClassVar[str] - agent_settings = Optional[AgentSettings] = None - history: List[AgentHistoryItem] = [] logger: AgentLog = AgentLog(agent_name) - - # def __init__( - # self, - # agent_name: Optional[str], - # action_space: Optional[ActionManager], - # observation_space: Optional[ObservationManager], - # reward_function: Optional[RewardFunction], - # agent_settings: Optional[AgentSettings] = None, - # ) -> None: - # """ - # Initialize an agent. - - # :param agent_name: Unique string identifier for the agent, for reporting and multi-agent purposes. - # :type agent_name: Optional[str] - # :param action_space: Action space for the agent. - # :type action_space: Optional[ActionManager] - # :param observation_space: Observation space for the agent. 
- # :type observation_space: Optional[ObservationSpace] - # :param reward_function: Reward function for the agent. - # :type reward_function: Optional[RewardFunction] - # :param agent_settings: Configurable Options for Abstracted Agents - # :type agent_settings: Optional[AgentSettings] - # """ - # self.agent_name: str = agent_name or "unnamed_agent" - # self.action_manager: Optional[ActionManager] = action_space - # self.observation_manager: Optional[ObservationManager] = observation_space - # self.reward_function: Optional[RewardFunction] = reward_function - # self.agent_settings = agent_settings or AgentSettings() - # self.history: List[AgentHistoryItem] = [] - # self.logger = AgentLog(agent_name) + history: List[AgentHistoryItem] = [] + action_manager: Optional[ActionManager] = None + observation_manager: Optional[ObservationManager] = None + reward_function: Optional[RewardFunction] = None + agent_settings: Optional[AgentSettings] = None def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: super().__init_subclass__(**kwargs) @@ -161,16 +129,25 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): raise ValueError(f"Cannot create a new agent under reserved name {identifier}") cls._registry[identifier] = cls + @property + def observation_manager(self) -> ObservationManager: + """Returns the agent's observation manager.""" + return self.config.observation_manager + + @property + def action_manager(self) -> ActionManager: + """Returns the agent's action manager.""" + return self.config.action_manager + + @property + def reward_function(self) -> RewardFunction: + """Returns the agent's reward function.""" + return self.config.reward_function + @classmethod def from_config(cls, config: Dict) -> "AbstractAgent": """Creates an agent component from a configuration dictionary.""" obj = cls(config=cls.ConfigSchema(**config)) - - # Pull managers out of config section for ease of use (?) 
- obj.observation_manager = obj.config.observation_manager - obj.action_manager = obj.config.action_manager - obj.reward_function = obj.config.reward_function - return obj def update_observation(self, state: Dict) -> ObsType: diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py index 55b2d08b..2f49decd 100644 --- a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -10,9 +10,6 @@ from primaite.game.agent.interface import AbstractScriptedAgent class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation_Agent"): """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" - next_execution_timestep: int = 0 - starting_node_idx: int = 0 - config: "DataManipulationAgent.ConfigSchema" class ConfigSchema(AbstractScriptedAgent.ConfigSchema): @@ -27,13 +24,23 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation super().__init__(*args, **kwargs) self.setup_agent() + @property + def next_execution_timestep(self): + """Returns the agents next execution timestep.""" + return self.config.next_execution_timestep + + @property + def starting_node_name(self): + """Returns the agents starting node name.""" + return self.config.starting_node_name + def _set_next_execution_timestep(self, timestep: int) -> None: """Set the next execution timestep with a configured random variance. :param timestep: The timestep to add variance to. 
""" random_timestep_increment = random.randint( - -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + -self.config.agent_settings.start_settings.variance, self.config.agent_settings.start_settings.variance ) self.next_execution_timestep = timestep + random_timestep_increment @@ -48,11 +55,11 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: - self.logger.debug(msg="Performing do nothing action") + self.config.logger.debug(msg="Performing do nothing action") return "do_nothing", {} - self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) - self.logger.info(msg="Performing a data manipulation attack!") + self._set_next_execution_timestep(timestep + self.config.agent_settings.start_settings.frequency) + self.config.logger.info(msg="Performing a data manipulation attack!") return "node_application_execute", { "node_name": self.config.starting_node_name, "application_name": self.config.starting_application_name, @@ -68,4 +75,4 @@ class DataManipulationAgent(AbstractScriptedAgent, identifier="Data_Manipulation # we are assuming that every node in the node manager has a data manipulation application at idx 0 num_nodes = len(self.action_manager.node_names) self.starting_node_idx = random.randint(0, num_nodes - 1) - self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}") + self.config.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}") diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index d7e2ed4a..03f3feec 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -11,10 +11,6 @@ from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent from primaite.game.agent.observations.observation_manager import ObservationManager from 
primaite.game.agent.rewards import RewardFunction, SharedReward -from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent -from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent -from primaite.game.agent.scripted_agents.tap001 import TAP001 from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator import SIM_OUTPUT from primaite.simulator.network.creation import NetworkNodeAdder @@ -178,7 +174,7 @@ class PrimaiteGame: obs = agent.observation_manager.current_observation action_choice, parameters = agent.get_action(obs, timestep=self.step_counter) if SIM_OUTPUT.save_agent_logs: - agent.logger.debug(f"Chosen Action: {action_choice}") + agent.config.logger.debug(f"Chosen Action: {action_choice}") request = agent.format_request(action_choice, parameters) response = self.simulation.apply_request(request) agent.process_action_response( @@ -548,77 +544,21 @@ class PrimaiteGame: # CREATE AGENT - # TODO: MAKE THIS BIT WORK AND NOT THE IF/ELSE CHAIN OF HORRORS - - # Pass through: - # config - # action manager - # observation_manager - # reward_function agent_config = agent_cfg.get("agent_settings", {}) - agent_config.update({"action_manager": action_space, - "observation_manager": obs_space, - "reward_function":reward_function}) + agent_config.update( + {"action_manager": action_space, "observation_manager": obs_space, "reward_function": reward_function} + ) # new_agent_cfg.update{} - new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config) - - # If blue agent is created, add to game.rl_agents - if agent_type == "ProxyAgent": - game.rl_agents[agent_cfg["ref"]] = new_agent - - if agent_type == "ProbabilisticAgent": - # TODO: implement non-random agents and fix this parsing - settings = agent_cfg.get("agent_settings", {}) - new_agent = ProbabilisticAgent( - 
agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - settings=settings, - ) - elif agent_type == "PeriodicAgent": - settings = PeriodicAgent.Settings(**agent_cfg.get("settings", {})) - new_agent = PeriodicAgent( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - settings=settings, - ) - - elif agent_type == "ProxyAgent": - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - new_agent = ProxyAgent( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - agent_settings=agent_settings, - ) - game.rl_agents[agent_cfg["ref"]] = new_agent - elif agent_type == "RedDatabaseCorruptingAgent": - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - - new_agent = DataManipulationAgent( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - agent_settings=agent_settings, - ) - elif agent_type == "TAP001": - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - new_agent = TAP001( - agent_name=agent_cfg["ref"], - action_space=action_space, - observation_space=obs_space, - reward_function=reward_function, - agent_settings=agent_settings, - ) + if agent_type in AbstractAgent._registry: + new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config) + # If blue agent is created, add to game.rl_agents + if agent_type == "ProxyAgent": + game.rl_agents[agent_cfg["ref"]] = new_agent else: msg = f"Configuration error: {agent_type} is not a valid agent type." _LOGGER.error(msg) raise ValueError(msg) + game.agents[agent_cfg["ref"]] = new_agent # Validate that if any agents are sharing rewards, they aren't forming an infinite loop.