diff --git a/CHANGELOG.md b/CHANGELOG.md index 5416bb9d..cdf7b5c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed the order of service health state - Fixed an issue where starting a node didn't start the services on it - Added support for SQL INSERT command. +- Added ability to log each agent's action choices in each step to a JSON file. diff --git a/docs/source/configuration/io_settings.rst b/docs/source/configuration/io_settings.rst index 96cc28fe..979dbfae 100644 --- a/docs/source/configuration/io_settings.rst +++ b/docs/source/configuration/io_settings.rst @@ -18,7 +18,7 @@ This section configures how PrimAITE saves data during simulation and training. checkpoint_interval: 10 # save_logs: True # save_transactions: False - # save_tensorboard_logs: False + save_agent_actions: True save_step_metadata: False save_pcap_logs: False save_sys_logs: False @@ -55,15 +55,13 @@ Defines how often to save the policy during training. *currently unused*. -``save_transactions`` ---------------------- -*currently unused*. +``save_agent_actions`` +---------------------- -``save_tensorboard_logs`` -------------------------- +Optional. Default value is ``True``. -*currently unused*. +If ``True``, this will create a JSON file each episode detailing every agent's action in each step of that episode, formatted according to the CAOS format. This includes scripted, RL, and red agents. 
``save_step_metadata`` ---------------------- diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index aea5d4fd..d0ba61b0 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -13,6 +13,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_agent_actions: true save_step_metadata: false save_pcap_logs: false save_sys_logs: true diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index b76a4f38..575182a8 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -15,6 +15,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_agent_actions: true save_step_metadata: false save_pcap_logs: false save_sys_logs: true @@ -35,7 +36,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: ProbabilisticAgent + type: ProbabilisticAgent agent_settings: action_probabilities: 0: 0.3 @@ -1283,7 +1284,6 @@ agents: - simulation: network: nmne_config: diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index c758c926..16453433 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -37,7 +37,7 @@ class DataManipulationAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: - return "DONOTHING", {"dummy": 0} + return "DONOTHING", {} self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cd88d832..394a8154 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -11,7 +11,6 @@ from 
primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAge from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.agent.scripted_agents import ProbabilisticAgent -from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -210,10 +209,6 @@ class PrimaiteGame: :return: A PrimaiteGame object. :rtype: PrimaiteGame """ - io_settings = cfg.get("io_settings", {}) - _ = SessionIO(SessionIOSettings(**io_settings)) - # Instantiating this ensures that the game saves to the correct output dir even without being part of a session - game = cls() game.options = PrimaiteGameOptions(**cfg["game"]) game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False @@ -415,7 +410,7 @@ class PrimaiteGame: # CREATE AGENT if agent_type == "ProbabilisticAgent": # TODO: implement non-random agents and fix this parsing - settings = agent_cfg.get("agent_settings") + settings = agent_cfg.get("agent_settings", {}) new_agent = ProbabilisticAgent( agent_name=agent_cfg["ref"], action_space=action_space, diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 86bc52cb..87638e7d 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -9,6 +9,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv from primaite import getLogger from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT _LOGGER = getLogger(__name__) @@ -35,6 +36,9 @@ class PrimaiteGymEnv(gymnasium.Env): self.episode_counter: int = 0 """Current episode number.""" + self.io = 
PrimaiteIO.from_config(game_config.get("io_settings", {})) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + @property def agent(self) -> ProxyAgent: """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" @@ -58,6 +62,10 @@ class PrimaiteGymEnv(gymnasium.Env): info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) + if self.io.settings.save_agent_actions: + self.io.store_agent_actions( + agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter + ) return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): @@ -82,6 +90,9 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) + if self.io.settings.save_agent_actions: + self.io.write_agent_actions(episode=self.episode_counter) + self.io.clear_agent_actions() self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 @@ -149,7 +160,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): """ self.game_config: Dict = env_config """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Reference to the primaite game""" self._agent_ids = list(self.game.rl_agents.keys()) """Agent ids. 
This is a list of strings of agent names.""" @@ -167,6 +178,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.action_space = gymnasium.spaces.Dict( {name: agent.action_manager.space for name, agent in self.agents.items()} ) + + self.io = PrimaiteIO.from_config(env_config.get("io_settings")) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + super().__init__() @property @@ -176,7 +191,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + if self.io.settings.save_agent_actions: + self.io.write_agent_actions(episode=self.episode_counter) + self.io.clear_agent_actions() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() @@ -199,7 +217,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() # 2. 
Advance timestep self.game.advance_timestep() @@ -218,6 +236,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: self._write_step_metadata_json(actions, state, rewards) + if self.io.settings.save_agent_actions: + self.io.store_agent_actions( + agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter + ) return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index b4b740e9..3e21ed16 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -1,55 +1,54 @@ +import json from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Dict, List, Optional from pydantic import BaseModel, ConfigDict -from primaite import PRIMAITE_PATHS +from primaite import getLogger, PRIMAITE_PATHS from primaite.simulator import SIM_OUTPUT - -class SessionIOSettings(BaseModel): - """Schema for session IO settings.""" - - model_config = ConfigDict(extra="forbid") - - save_final_model: bool = True - """Whether to save the final model right at the end of training.""" - save_checkpoints: bool = False - """Whether to save a checkpoint model every `checkpoint_interval` episodes""" - checkpoint_interval: int = 10 - """How often to save a checkpoint model (if save_checkpoints is True).""" - save_logs: bool = True - """Whether to save logs""" - save_transactions: bool = True - """Whether to save transactions, If true, the session path will have a transactions folder.""" - save_tensorboard_logs: bool = False - """Whether to save tensorboard logs. 
If true, the session path will have a tensorboard_logs folder.""" - save_step_metadata: bool = False - """Whether to save the RL agents' action, environment state, and other data at every single step.""" - save_pcap_logs: bool = False - """Whether to save PCAP logs.""" - save_sys_logs: bool = False - """Whether to save system logs.""" +_LOGGER = getLogger(__name__) -class SessionIO: +class PrimaiteIO: """ Class for managing session IO. - Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. + Currently it's handling path generation, but could expand to handle loading, transaction, and so on. """ - def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: - self.settings: SessionIOSettings = settings + class Settings(BaseModel): + """Config schema for PrimaiteIO object.""" + + model_config = ConfigDict(extra="forbid") + + save_logs: bool = True + """Whether to save logs""" + save_agent_actions: bool = True + """Whether to save a log of all agents' actions every step.""" + save_step_metadata: bool = False + """Whether to save the RL agents' action, environment state, and other data at every single step.""" + save_pcap_logs: bool = False + """Whether to save PCAP logs.""" + save_sys_logs: bool = False + """Whether to save system logs.""" + + def __init__(self, settings: Optional[Settings] = None) -> None: + """ + Init the PrimaiteIO object. + + Note: Instantiating this object creates a new directory for outputs, and sets the global SIM_OUTPUT variable. + It is intended that this object is instantiated when a new environment is created. 
+ """ + self.settings = settings or PrimaiteIO.Settings() self.session_path: Path = self.generate_session_path() # set global SIM_OUTPUT path SIM_OUTPUT.path = self.session_path / "simulation_output" SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs - # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's - # possible refactor needed + self.agent_action_log: List[Dict] = [] def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" @@ -68,3 +67,56 @@ class SessionIO: def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path: """Return the path where the checkpoint model will be saved (excluding filename extension).""" return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt" + + def generate_agent_actions_save_path(self, episode: int) -> Path: + """Return the path where agent actions will be saved.""" + return self.session_path / "agent_actions" / f"episode_{episode}.json" + + def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None: + """Cache agent actions for a particular step. + + :param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected + format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters] + CAOS action is a string representing one the CAOS actions. + parameters is a dict of parameter names and values for that particular CAOS action. + For example: + { + 'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}), + 'defender': ('DO_NOTHING', {}) + } + :type agent_actions: Dict + :param timestep: Simulation timestep when these actions occurred. 
+ :type timestep: int + """ + self.agent_action_log.append( + [ + { + "episode": episode, + "timestep": timestep, + "agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()}, + } + ] + ) + + def write_agent_actions(self, episode: int) -> None: + """Take the contents of the agent action log and write it to a file. + + :param episode: Episode number + :type episode: int + """ + path = self.generate_agent_actions_save_path(episode=episode) + path.parent.mkdir(exist_ok=True, parents=True) + path.touch() + _LOGGER.info(f"Saving agent action log to {path}") + with open(path, "w") as file: + json.dump(self.agent_action_log, fp=file, indent=1) + + def clear_agent_actions(self) -> None: + """Reset the agent action log back to an empty list.""" + self.agent_action_log = [] + + @classmethod + def from_config(cls, config: Dict) -> "PrimaiteIO": + """Create an instance of PrimaiteIO based on a configuration dict.""" + new = cls() + return new diff --git a/src/primaite/session/policy/sb3.py b/src/primaite/session/policy/sb3.py index 254baf4d..6220371d 100644 --- a/src/primaite/session/policy/sb3.py +++ b/src/primaite/session/policy/sb3.py @@ -39,9 +39,9 @@ class SB3Policy(PolicyABC, identifier="SB3"): def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - if self.session.io_manager.settings.save_checkpoints: + if self.session.save_checkpoints: checkpoint_callback = CheckpointCallback( - save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval, + save_freq=timesteps_per_episode * self.session.checkpoint_interval, save_path=self.session.io_manager.generate_model_save_path("sb3"), name_prefix="sb3_model", ) diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index d244f6b0..9c935ae3 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -1,3 +1,4 @@ +# raise DeprecationWarning("This module is deprecated") from enum
import Enum from pathlib import Path from typing import Dict, List, Literal, Optional, Union @@ -5,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv -from primaite.session.io import SessionIO, SessionIOSettings +from primaite.session.io import PrimaiteIO # from primaite.game.game import PrimaiteGame from primaite.session.policy.policy import PolicyABC @@ -53,12 +54,18 @@ class PrimaiteSession: self.policy: PolicyABC """The reinforcement learning policy.""" - self.io_manager: Optional["SessionIO"] = None + self.io_manager: Optional["PrimaiteIO"] = None """IO manager for the session.""" self.game_cfg: Dict = game_cfg """Primaite Game object for managing main simulation loop and agents.""" + self.save_checkpoints: bool = False + """Whether to save checkpoints.""" + + self.checkpoint_interval: int = 10 + """If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes.""" + def start_session(self) -> None: """Commence the training/eval session.""" print("Starting Primaite Session") @@ -89,12 +96,13 @@ class PrimaiteSession: def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary.""" # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - io_manager = SessionIO(SessionIOSettings(**io_settings)) + io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {})) sess = cls(game_cfg=cfg) sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) + sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints") + sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval") # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent":