From 759965587982931bd039df6a4084ff7aa364cbdd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 11 Mar 2024 20:10:08 +0000 Subject: [PATCH] Add agent action history --- .../game/agent/data_manipulation_bot.py | 5 +- src/primaite/game/agent/interface.py | 37 ++++++++++++-- src/primaite/game/agent/rewards.py | 37 ++++++++++---- src/primaite/game/game.py | 51 +++++++++---------- src/primaite/interface/request.py | 4 +- .../Data-Manipulation-E2E-Demonstration.ipynb | 10 ++-- src/primaite/session/environment.py | 25 ++++----- src/primaite/session/io.py | 41 +++------------ src/primaite/simulator/core.py | 4 +- 9 files changed, 110 insertions(+), 104 deletions(-) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index 16453433..d3ec19cb 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -14,7 +14,7 @@ class DataManipulationAgent(AbstractScriptedAgent): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.reset_agent_for_episode() + self.setup_agent() def _set_next_execution_timestep(self, timestep: int) -> None: """Set the next execution timestep with a configured random variance. 
@@ -43,9 +43,8 @@ class DataManipulationAgent(AbstractScriptedAgent): return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} - def reset_agent_for_episode(self) -> None: + def setup_agent(self) -> None: """Set the next execution timestep when the episode resets.""" - super().reset_agent_for_episode() self._select_start_node() self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 88848479..0531b25f 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,6 +1,6 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium.core import ActType, ObsType from pydantic import BaseModel, model_validator @@ -8,11 +8,31 @@ from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.interface.request import RequestFormat, RequestResponse if TYPE_CHECKING: pass +class AgentActionHistoryItem(BaseModel): + """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" + + timestep: int + """Timestep of this action.""" + + action: str + """CAOS Action name.""" + + parameters: Dict[str, Any] + """CAOS parameters for the given action.""" + + request: RequestFormat + """The request that was sent to the simulation based on the CAOS action chosen.""" + + response: RequestResponse + """The response sent back by the simulator for this action.""" + + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -90,6 +110,7 @@ class AbstractAgent(ABC): self.observation_manager: 
Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() + self.action_history: List[AgentActionHistoryItem] = [] def update_observation(self, state: Dict) -> ObsType: """ @@ -109,7 +130,7 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.update(state) + return self.reward_function.update(state=state, last_action_response=self.action_history[-1]) @abstractmethod def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: @@ -138,9 +159,15 @@ class AbstractAgent(ABC): request = self.action_manager.form_request(action_identifier=action, action_options=options) return request - def reset_agent_for_episode(self) -> None: - """Agent reset logic should go here.""" - pass + def process_action_response( + self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + ) -> None: + """Process the response from the most recent action.""" + self.action_history.append( + AgentActionHistoryItem( + timestep=timestep, action=action, parameters=parameters, request=request, response=response + ) + ) class AbstractScriptedAgent(AbstractAgent): diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index 8c8e36ad..6ab5aa42 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -26,11 +26,14 @@ the structure: ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, Type +from typing import Dict, List, Tuple, Type, TYPE_CHECKING from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.agent.interface import AgentActionHistoryItem + _LOGGER = getLogger(__name__) @@ -38,7 +41,9 @@ class AbstractReward: """Base class for reward function components.""" @abstractmethod - def 
calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state.""" return 0.0 @@ -58,7 +63,9 @@ class AbstractReward: class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state.""" return 0.0 @@ -98,7 +105,9 @@ class DatabaseFileIntegrity(AbstractReward): file_name, ] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -153,7 +162,9 @@ class WebServer404Penalty(AbstractReward): """ self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -206,7 +217,9 @@ class WebpageUnavailablePenalty(AbstractReward): self._node = node_hostname self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """ Calculate the reward based on current simulation state. 
@@ -255,13 +268,17 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward): self._node = node_hostname self.location_in_state = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] - def calculate(self, state: Dict) -> float: + def calculate( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """ Calculate the reward based on current simulation state. :param state: The current state of the simulation. :type state: Dict """ + if last_action_response.request == ["network", "node", "client_2", "application", "DatabaseClient", "execute"]: + pass # TODO db_state = access_from_nested_dict(state, self.location_in_state) if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") @@ -313,7 +330,9 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def update(self, state: Dict) -> float: + def update( + self, state: Dict, last_action_response: "AgentActionHistoryItem" + ) -> float: # todo maybe make last_action_response optional? """Calculate the overall reward for the current state. :param state: The current state of the simulation. 
@@ -323,7 +342,7 @@ class RewardFunction: for comp_and_weight in self.reward_components: comp = comp_and_weight[0] weight = comp_and_weight[1] - total += weight * comp.calculate(state=state) + total += weight * comp.calculate(state=state, last_action_response=last_action_response) self.current_reward = total return self.current_reward diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index c94cb3ad..1cc8cfed 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,6 +1,6 @@ """PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address -from typing import Dict, List, Tuple +from typing import Dict, List from pydantic import BaseModel, ConfigDict @@ -130,49 +130,44 @@ class PrimaiteGame: """ _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") - # Get the current state of the simulation - sim_state = self.get_sim_state() - - # Update agents' observations and rewards based on the current state - self.update_agents(sim_state) - # Apply all actions to simulation as requests - self.apply_agent_actions() + self.apply_agent_actions() # Advance timestep self.advance_timestep() + # Get the current state of the simulation + sim_state = self.get_sim_state() + + # Update agents' observations and rewards based on the current state, and the response from the last action + self.update_agents(state=sim_state) + def get_sim_state(self) -> Dict: """Get the current state of the simulation.""" return self.simulation.describe_state() def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for _, agent in self.agents.items(): - agent.update_observation(state) - agent.update_reward(state) + for agent_name, agent in self.agents.items(): + if self.step_counter > 0: # can't get reward before first action + agent.update_reward(state=state) + agent.update_observation(state=state)
agent.reward_function.total_reward += agent.reward_function.current_reward - def apply_agent_actions(self) -> Dict[str, Tuple[str, Dict]]: - """ - Apply all actions to simulation as requests. - - :return: A recap of each agent's actions, in CAOS format. - :rtype: Dict[str, Tuple[str, Dict]] - - """ - agent_actions = {} + def apply_agent_actions(self) -> None: + """Apply all actions to simulation as requests.""" for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation - action_choice, options = agent.get_action(obs, timestep=self.step_counter) - request = agent.format_request(action_choice, options) + action_choice, parameters = agent.get_action(obs, timestep=self.step_counter) + request = agent.format_request(action_choice, parameters) response = self.simulation.apply_request(request) - agent_actions[agent.agent_name] = { - "action": action_choice, - "parameters": options, - "response": response.model_dump(), - } - return agent_actions + agent.process_action_response( + timestep=self.step_counter, + action=action_choice, + parameters=parameters, + request=request, + response=response, + ) def advance_timestep(self) -> None: """Advance timestep.""" diff --git a/src/primaite/interface/request.py b/src/primaite/interface/request.py index 8e922ef9..bc076599 100644 --- a/src/primaite/interface/request.py +++ b/src/primaite/interface/request.py @@ -1,7 +1,9 @@ -from typing import Dict, ForwardRef, Literal +from typing import Dict, ForwardRef, List, Literal, Union from pydantic import BaseModel, ConfigDict, StrictBool, validate_call +RequestFormat = List[Union[str, int, float]] + RequestResponse = ForwardRef("RequestResponse") """This makes it possible to type-hint RequestResponse.from_bool return type.""" diff --git a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 1d7cb157..b2522c2b 100644 --- 
a/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -373,7 +373,7 @@ "# Imports\n", "from primaite.config.load import data_manipulation_config_path\n", "from primaite.session.environment import PrimaiteGymEnv\n", - "from primaite.game.game import PrimaiteGame\n", + "from primaite.game.agent.interface import AgentActionHistoryItem\n", "import yaml\n", "from pprint import pprint\n" ] @@ -425,14 +425,14 @@ "source": [ "def friendly_output_red_action(info):\n", " # parse the info dict form step output and write out what the red agent is doing\n", - " red_info = info['agent_actions']['data_manipulation_attacker']\n", - " red_action = red_info['action']\n", + " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_action = red_info.action\n", " if red_action == 'DONOTHING':\n", " red_str = 'DO NOTHING'\n", " elif red_action == 'NODE_APPLICATION_EXECUTE':\n", - " client = \"client 1\" if red_info['parameters']['node_id'] == 0 else \"client 2\"\n", + " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", " red_str = f\"ATTACK from {client}\"\n", - " return red_str\n" + " return red_str" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 87638e7d..64534b04 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -49,23 +49,20 @@ class PrimaiteGymEnv(gymnasium.Env): # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() - self.game.update_agents(state) - next_obs = self._get_obs() + next_obs = self._get_obs() # this doesn't update observation, just gets the current observation reward = 
self.agent.reward_function.current_reward terminated = False truncated = self.game.calculate_truncated() - info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. + info = { + "agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()} + } # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) - if self.io.settings.save_agent_actions: - self.io.store_agent_actions( - agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter - ) return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): @@ -91,13 +88,13 @@ class PrimaiteGymEnv(gymnasium.Env): f"avg. reward: {self.agent.reward_function.total_reward}" ) if self.io.settings.save_agent_actions: - self.io.write_agent_actions(episode=self.episode_counter) - self.io.clear_agent_actions() + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() - self.game.update_agents(state) + self.game.update_agents(state=state) next_obs = self._get_obs() info = {} return next_obs, info @@ -217,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() # 2. 
Advance timestep self.game.advance_timestep() @@ -236,10 +233,6 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: self._write_step_metadata_json(actions, state, rewards) - if self.io.settings.save_agent_actions: - self.io.store_agent_actions( - agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter - ) return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index ed2b4d62..87289c43 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -48,8 +48,6 @@ class PrimaiteIO: SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs - self.agent_action_log: List[Dict] = [] - def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" if timestamp is None: @@ -72,48 +70,23 @@ class PrimaiteIO: """Return the path where agent actions will be saved.""" return self.session_path / "agent_actions" / f"episode_{episode}.json" - def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None: - """Cache agent actions for a particular step. - - :param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected - format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters] - CAOS action is a string representing one the CAOS actions. - parameters is a dict of parameter names and values for that particular CAOS action. - For example: - { - 'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}), - 'defender': ('DO_NOTHING', {}) - } - :type agent_actions: Dict - :param timestep: Simulation timestep when these actions occurred. 
- :type timestep: int - """ - self.agent_action_log.append( - [ - { - "episode": episode, - "timestep": timestep, - "agent_actions": agent_actions, - } - ] - ) - - def write_agent_actions(self, episode: int) -> None: + def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None: """Take the contents of the agent action log and write it to a file. :param episode: Episode number :type episode: int """ + data = {} + longest_history = max((len(hist) for hist in agent_actions.values()), default=0) + for i in range(longest_history): + data[i] = {"timestep": i, "episode": episode, **{name: acts[i].model_dump() for name, acts in agent_actions.items() if i < len(acts)}} + path = self.generate_agent_actions_save_path(episode=episode) path.parent.mkdir(exist_ok=True, parents=True) path.touch() _LOGGER.info(f"Saving agent action log to {path}") with open(path, "w") as file: - json.dump(self.agent_action_log, fp=file, indent=1) - - def clear_agent_actions(self) -> None: - """Reset the agent action log back to an empty dictionary.""" - self.agent_action_log = [] + json.dump(data, fp=file, indent=1) @classmethod def from_config(cls, config: Dict) -> "PrimaiteIO": diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index aeb4e865..6da8a2f8 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -7,12 +7,10 @@ from uuid import uuid4 from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger -from primaite.interface.request import RequestResponse +from primaite.interface.request import RequestFormat, RequestResponse _LOGGER = getLogger(__name__) -RequestFormat = List[Union[str, int, float]] - class RequestPermissionValidator(BaseModel): """