From a5c4f7797d34416fb7bec946a59843fdcbc2ba31 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 4 Dec 2023 10:42:20 +0000 Subject: [PATCH] Make saving step metadata optional --- src/primaite/config/_package_data/example_config.yaml | 1 + .../config/_package_data/example_config_2_rl_agents.yaml | 1 + src/primaite/game/game.py | 9 +++++++++ src/primaite/session/environment.py | 6 ++++-- src/primaite/session/io.py | 2 ++ 5 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 7d5b50d6..24f9945d 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -13,6 +13,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_step_metadata: false game: diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index b811bfa5..9c2acaae 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -9,6 +9,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_step_metadata: false game: diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index a36cbea9..8c32f41d 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NIC, NodeOperatingState from primaite.simulator.network.hardware.nodes.computer import Computer from primaite.simulator.network.hardware.nodes.router import ACLAction, Router @@ -84,6 +85,9 @@ class PrimaiteGame: self.ref_map_links: Dict[str, str] = {} """Mapping from human-readable link reference to link object. Used when parsing config files.""" + self.save_step_metadata: bool = False + """Whether to save the RL agents' action, environment state, and other data at every single step.""" + def step(self): """ Perform one step of the simulation/agent loop. @@ -180,8 +184,13 @@ class PrimaiteGame: :return: A PrimaiteGame object. :rtype: PrimaiteGame """ + io_settings = cfg.get("io_settings", {}) + _ = SessionIO(SessionIOSettings(**io_settings)) + # Instantiating this ensures that the game saves to the correct output dir even without being part of a session + game = cls() game.options = PrimaiteGameOptions(**cfg["game"]) + game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False # 1. create simulation sim = game.simulation diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index dfee9a2f..3d43e338 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -40,7 +40,8 @@ class PrimaiteGymEnv(gymnasium.Env): terminated = False truncated = self.game.calculate_truncated() info = {} - self._write_step_metadata_json(action, state, reward) + if self.game.save_step_metadata: + self._write_step_metadata_json(action, state, reward) print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}") return next_obs, reward, terminated, truncated, info @@ -183,7 +184,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): infos = {} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() - self._write_step_metadata_json(actions, state, rewards) + if self.game.save_step_metadata: + self._write_step_metadata_json(actions, state, rewards) return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index e0b849c9..0d80a385 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -25,6 +25,8 @@ class SessionIOSettings(BaseModel): """Whether to save transactions, If true, the session path will have a transactions folder.""" save_tensorboard_logs: bool = False """Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder.""" + save_step_metadata: bool = False + """Whether to save the RL agents' action, environment state, and other data at every single step.""" class SessionIO: