From 75d4ef2dfd0ac0e9a5bc9cc45f96f41bd8f04977 Mon Sep 17 00:00:00 2001 From: Charlie Crane Date: Wed, 20 Nov 2024 17:51:05 +0000 Subject: [PATCH] #2869 - eod commit. Updates to AbstractAgent.from_config, and some minor tweaks to PrimaiteGame --- src/primaite/game/agent/interface.py | 22 +++++++++++++------ .../scripted_agents/probabilistic_agent.py | 4 ++-- .../game/agent/scripted_agents/tap001.py | 2 +- src/primaite/game/game.py | 19 ++++++++-------- 4 files changed, 27 insertions(+), 20 deletions(-) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 7adaab69..88557956 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -97,11 +97,12 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): _registry: ClassVar[Dict[str, Type[AbstractAgent]]] = {} - config: "AbstractAgent.ConfigSchema" action_manager: Optional[ActionManager] observation_manager: Optional[ObservationManager] reward_function: Optional[RewardFunction] + config: "AbstractAgent.ConfigSchema" + class ConfigSchema(BaseModel): """ Configuration Schema for AbstractAgents. @@ -163,7 +164,14 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): @classmethod def from_config(cls, config: Dict) -> "AbstractAgent": """Creates an agent component from a configuration dictionary.""" - return cls(config=cls.ConfigSchema(**config)) + obj = cls(config=cls.ConfigSchema(**config)) + + # Pull managers out of config section for ease of use (?) + obj.observation_manager = obj.config.observation_manager + obj.action_manager = obj.config.action_manager + obj.reward_function = obj.config.reward_function + + return obj def update_observation(self, state: Dict) -> ObsType: """ @@ -172,7 +180,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): state : dict state directly from simulation.describe_state output : dict state according to CAOS. """ - return self.config.observation_manager.update(state) + return self.observation_manager.update(state) def update_reward(self, state: Dict) -> float: """ @@ -183,7 +191,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): :return: Reward from the state. :rtype: float """ - return self.config.reward_function.update(state=state, last_action_response=self.history[-1]) + return self.reward_function.update(state=state, last_action_response=self.config.history[-1]) @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -201,13 +209,13 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): """ # in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39, # then use a bespoke conversion to take 1-40 int back into CAOS action - return ("DO_NOTHING", {}) + return ("do_nothing", {}) def format_request(self, action: Tuple[str, Dict], options: Dict[str, int]) -> List[str]: # this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator. # therefore the execution definition needs to be a mapping from CAOS into SIMULATOR """Format action into format expected by the simulator, and apply execution definition if applicable.""" - request = self.config.action_manager.form_request(action_identifier=action, action_options=options) + request = self.action_manager.form_request(action_identifier=action, action_options=options) return request def process_action_response( @@ -222,7 +230,7 @@ class AbstractAgent(BaseModel, ABC, identifier="Abstract_Agent"): def save_reward_to_history(self) -> None: """Update the most recent history item with the reward value.""" - self.config.history[-1].reward = self.config.reward_function.current_reward + self.config.history[-1].reward = self.reward_function.current_reward class AbstractScriptedAgent(AbstractAgent, identifier="Abstract_Scripted_Agent"): diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py index b8df7838..02ac5931 100644 --- a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -15,7 +15,7 @@ from primaite.game.agent.rewards import RewardFunction class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent"): """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" - class Settings(pydantic.BaseModel): + class ConfigSchema(pydantic.BaseModel): """Config schema for Probabilistic agent settings.""" model_config = pydantic.ConfigDict(extra="forbid") @@ -60,7 +60,7 @@ class ProbabilisticAgent(AbstractScriptedAgent, identifier="Probabilistic_Agent" # The random number seed for np.random is dependent on whether a random number seed is set # in the config file. If there is one it is processed by set_random_seed() in environment.py # and as a consequence the the sequence of rng_seed's used here will be repeatable. - self.settings = ProbabilisticAgent.Settings(**settings) + self.settings = ProbabilisticAgent.ConfigSchema(**settings) rng_seed = np.random.randint(0, 65535) self.rng = np.random.default_rng(rng_seed) diff --git a/src/primaite/game/agent/scripted_agents/tap001.py b/src/primaite/game/agent/scripted_agents/tap001.py index 78cb9293..7365fd88 100644 --- a/src/primaite/game/agent/scripted_agents/tap001.py +++ b/src/primaite/game/agent/scripted_agents/tap001.py @@ -16,7 +16,7 @@ class TAP001(AbstractScriptedAgent, identifier="TAP001"): Scripted Red Agent. Capable of one action; launching the kill-chain (Ransomware Application) """ - # TODO: Link with DataManipulationAgent via a parent "TAP" agent class. + # TODO: Link with DataManipulationAgent class via a parent "TAP" agent class. config: "TAP001.ConfigSchema" diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index 2ef7b1c5..d7e2ed4a 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -555,17 +555,16 @@ class PrimaiteGame: # action manager # observation_manager # reward_function - - new_agent_cfg = { - "action_manager": action_space, - "agent_name": agent_cfg["ref"], - "observation_manager": obs_space, - "agent_settings": agent_cfg.get("agent_settings", {}), - "reward_function": reward_function, - } - new_agent_cfg = agent_cfg["settings"] + agent_config = agent_cfg.get("agent_settings", {}) + agent_config.update({"action_manager": action_space, + "observation_manager": obs_space, + "reward_function":reward_function}) # new_agent_cfg.update{} - new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=new_agent_cfg) + new_agent = AbstractAgent._registry[agent_cfg["type"]].from_config(config=agent_config) + + # If blue agent is created, add to game.rl_agents + if agent_type == "ProxyAgent": + game.rl_agents[agent_cfg["ref"]] = new_agent if agent_type == "ProbabilisticAgent": # TODO: implement non-random agents and fix this parsing