From cc04efb31db63f57869c0ce833f30134639f930a Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 1 Dec 2023 16:37:58 +0000 Subject: [PATCH] #2085 - Added step metadata json file dumps to the environments. Fixed serialization issues in the Switch and ACLRule classes. --- docs/source/primaite_session.rst | 4 +- src/primaite/session/environment.py | 47 ++++++++++++++----- .../network/hardware/nodes/router.py | 4 +- .../network/hardware/nodes/switch.py | 2 +- 4 files changed, 41 insertions(+), 16 deletions(-) diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index f3ef0399..706397b6 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -31,4 +31,6 @@ Outputs Running a session creates a session output directory in your user data folder. The filepath looks like this: ``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node, -the saved agent checkpoints, and final model. +the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that +contains the action, reward, and simulation state. These can be found in +``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_/step_metadata/step_.json`` diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 3c164878..9c86aee0 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -5,9 +5,9 @@ import gymnasium from gymnasium.core import ActType, ObsType from ray.rllib.env.multi_agent_env import MultiAgentEnv -from primaite import PRIMAITE_PATHS from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.simulator import SIM_OUTPUT class PrimaiteGymEnv(gymnasium.Env): @@ -33,17 +33,6 @@ class PrimaiteGymEnv(gymnasium.Env): self.game.advance_timestep() state = self.game.get_sim_state() - # Create state suitable for dumping to file. - # dump_state = {self.game.episode_counter: {self.game.step_counter: state}} - - # Dump to file - # if os.path.isfile(PRIMAITE_PATHS.episode_steps_log_file_path): - with open(PRIMAITE_PATHS.episode_log_file_path, "a", encoding="utf-8") as f: - # f.write(str(dump_state)) - # f.write("\n=================\n") - # f.flush() - json.dump(state, f) - self.game.update_agents(state) next_obs = self._get_obs() @@ -51,9 +40,26 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = {} + self._write_step_metadata_json(action, state, reward) print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}") return next_obs, reward, terminated, truncated, info + def _write_step_metadata_json(self, action: int, state: Dict, reward: int): + output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{self.game.step_counter}.json" + + data = { + "episode": self.game.episode_counter, + "step": self.game.step_counter, + "action": int(action), + "reward": int(reward), + "state": state, + } + with open(path, "w") as file: + json.dump(data, file) + def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" self.game.reset() @@ -173,8 +179,25 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): infos = {} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() + self._write_step_metadata_json(actions, state, rewards) return next_obs, rewards, terminateds, truncateds, infos + def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): + output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + + output_dir.mkdir(parents=True, exist_ok=True) + path = output_dir / f"step_{self.game.step_counter}.json" + + data = { + "episode": self.game.episode_counter, + "step": self.game.step_counter, + "actions": {agent_name: int(action) for agent_name, action in actions.items()}, + "reward": rewards, + "state": state, + } + with open(path, "w") as file: + json.dump(data, file) + def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index 0017215a..0234934d 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -66,9 +66,9 @@ class ACLRule(SimComponent): state = super().describe_state() state["action"] = self.action.value state["protocol"] = self.protocol.value if self.protocol else None - state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None + state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None state["src_port"] = self.src_port.value if self.src_port else None - state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None + state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None state["dst_port"] = self.dst_port.value if self.dst_port else None return state diff --git a/src/primaite/simulator/network/hardware/nodes/switch.py b/src/primaite/simulator/network/hardware/nodes/switch.py index fe61509c..92999b88 100644 --- a/src/primaite/simulator/network/hardware/nodes/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/switch.py @@ -57,7 +57,7 @@ class Switch(Node): state = super().describe_state() state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()} state["num_ports"] = self.num_ports # redundant? - state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()} + state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()} return state def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):