Make saving step metadata optional
This commit is contained in:
@@ -13,6 +13,7 @@ training_config:
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
|
||||
|
||||
game:
|
||||
|
||||
@@ -9,6 +9,7 @@ training_config:
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
|
||||
|
||||
game:
|
||||
|
||||
@@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -84,6 +85,9 @@ class PrimaiteGame:
|
||||
self.ref_map_links: Dict[str, str] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
|
||||
self.save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -180,8 +184,13 @@ class PrimaiteGame:
|
||||
:return: A PrimaiteGame object.
|
||||
:rtype: PrimaiteGame
|
||||
"""
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
_ = SessionIO(SessionIOSettings(**io_settings))
|
||||
# Instantiating this ensures that the game saves to the correct output dir even without being part of a session
|
||||
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(**cfg["game"])
|
||||
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
|
||||
|
||||
# 1. create simulation
|
||||
sim = game.simulation
|
||||
|
||||
@@ -40,7 +40,8 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
print(f"Episode: {self.game.episode_counter}, Step: {self.game.step_counter}, Reward: {reward}")
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
@@ -183,7 +184,8 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
infos = {}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
|
||||
@@ -25,6 +25,8 @@ class SessionIOSettings(BaseModel):
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
|
||||
|
||||
class SessionIO:
|
||||
|
||||
Reference in New Issue
Block a user