diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index f3ef0399..706397b6 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -31,4 +31,6 @@ Outputs Running a session creates a session output directory in your user data folder. The filepath looks like this: ``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node, -the saved agent checkpoints, and final model. +the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that +contains the action, reward, and simulation state. These can be found in +``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<episode>/step_metadata/step_<step>.json`` diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 28245d33..c58f0103 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -1,4 +1,5 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +import datetime as datetime import logging import logging.config import shutil @@ -38,6 +39,7 @@ class _PrimaitePaths: self.app_config_file_path = self.generate_app_config_file_path() self.app_log_dir_path = self.generate_app_log_dir_path() self.app_log_file_path = self.generate_app_log_file_path() + self.episode_log_file_path = self.generate_episode_log_file_path() def _get_dirs_properties(self) -> List[str]: class_items = self.__class__.__dict__.items() @@ -105,6 +107,13 @@ class _PrimaitePaths: """The PrimAITE app log file path.""" return self.app_log_dir_path / "primaite.log" + def generate_episode_log_file_path(self) -> Path: + """The PrimAITE app episode step log file path.""" + date_string = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") + self.episode_log_dir_path = self.app_log_dir_path / date_string + self.episode_log_dir_path.mkdir(exist_ok=True, parents=True) + return self.episode_log_dir_path / "episode.log" + def __repr__(self) 
-> str: properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()]) return f"{self.__class__.__name__}({properties_str})" diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 7d5b50d6..24f9945d 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -13,6 +13,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_step_metadata: false game: diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index b811bfa5..9c2acaae 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -9,6 +9,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_step_metadata: false game: diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index a36cbea9..8c32f41d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -84,6 +85,9 @@ class PrimaiteGame: self.ref_map_links: Dict[str, str] = {} """Mapping from human-readable link reference to link object. 
Used when parsing config files.""" + self.save_step_metadata: bool = False + """Whether to save the RL agents' action, environment state, and other data at every single step.""" + def step(self): """ Perform one step of the simulation/agent loop. @@ -180,8 +184,13 @@ class PrimaiteGame: :return: A PrimaiteGame object. :rtype: PrimaiteGame """ + io_settings = cfg.get("io_settings", {}) + _ = SessionIO(SessionIOSettings(**io_settings)) + # Instantiating this ensures that the game saves to the correct output dir even without being part of a session + game = cls() game.options = PrimaiteGameOptions(**cfg["game"]) + game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False # 1. create simulation sim = game.simulation diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index c2f19f36..ca71a0c0 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple import gymnasium @@ -6,6 +7,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.simulator import SIM_OUTPUT class PrimaiteGymEnv(gymnasium.Env): @@ -30,6 +32,7 @@ class PrimaiteGymEnv(gymnasium.Env): self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() + self.game.update_agents(state) next_obs = self._get_obs() @@ -37,8 +40,26 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = {} + if self.game.save_step_metadata: + self._write_step_metadata_json(action, state, reward) return next_obs, reward, terminated, truncated, info + def _write_step_metadata_json(self, action: int, state: Dict, reward: int): + output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + + 
output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{self.game.step_counter}.json" + + data = { + "episode": self.game.episode_counter, + "step": self.game.step_counter, + "action": int(action), + "reward": int(reward), + "state": state, + } + with open(path, "w") as file: + json.dump(data, file) + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" print( @@ -162,8 +183,26 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): infos = {} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() + if self.game.save_step_metadata: + self._write_step_metadata_json(actions, state, rewards) return next_obs, rewards, terminateds, truncateds, infos + def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): + output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{self.game.step_counter}.json" + + data = { + "episode": self.game.episode_counter, + "step": self.game.step_counter, + "actions": {agent_name: int(action) for agent_name, action in actions.items()}, + "reward": rewards, + "state": state, + } + with open(path, "w") as file: + json.dump(data, file) + def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index e0b849c9..0d80a385 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -25,6 +25,8 @@ class SessionIOSettings(BaseModel): """Whether to save transactions, If true, the session path will have a transactions folder.""" save_tensorboard_logs: bool = False """Whether to save tensorboard logs. 
If true, the session path will have a tenorboard_logs folder.""" + save_step_metadata: bool = False + """Whether to save the RL agents' action, environment state, and other data at every single step.""" class SessionIO: diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 0017215a..0234934d 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -66,9 +66,9 @@ class ACLRule(SimComponent): state = super().describe_state() state["action"] = self.action.value state["protocol"] = self.protocol.value if self.protocol else None - state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None + state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None state["src_port"] = self.src_port.value if self.src_port else None - state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None + state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None state["dst_port"] = self.dst_port.value if self.dst_port else None return state diff --git a/src/primaite/simulator/network/hardware/nodes/switch.py b/src/primaite/simulator/network/hardware/nodes/switch.py index fe61509c..92999b88 100644 --- a/src/primaite/simulator/network/hardware/nodes/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/switch.py @@ -57,7 +57,7 @@ class Switch(Node): state = super().describe_state() state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()} state["num_ports"] = self.num_ports # redundant? - state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()} + state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} return state def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):