Make saving step metadata optional

This commit is contained in:
Marek Wolan
2023-12-04 10:42:20 +00:00
parent 8ea9db2d34
commit a5c4f7797d
5 changed files with 17 additions and 2 deletions

View File

@@ -13,6 +13,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
game:

View File

@@ -9,6 +9,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
game:

View File

@@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -84,6 +85,9 @@ class PrimaiteGame:
self.ref_map_links: Dict[str, str] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -180,8 +184,13 @@ class PrimaiteGame:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
io_settings = cfg.get("io_settings", {})
_ = SessionIO(SessionIOSettings(**io_settings))
# Instantiating this ensures that the game saves to the correct output dir even without being part of a session
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
# 1. create simulation
sim = game.simulation

View File

@@ -40,7 +40,8 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {}
self._write_step_metadata_json(action, state, reward)
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}")
return next_obs, reward, terminated, truncated, info
@@ -183,7 +184,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
infos = {}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
self._write_step_metadata_json(actions, state, rewards)
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):

View File

@@ -25,6 +25,8 @@ class SessionIOSettings(BaseModel):
"""Whether to save transactions, If true, the session path will have a transactions folder."""
save_tensorboard_logs: bool = False
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
class SessionIO: