From 5bda952ead90e566593681eefdfa9d223c84af3a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 17 Nov 2023 10:20:26 +0000 Subject: [PATCH] Fix sim output --- src/primaite/game/io.py | 5 ++++ src/primaite/game/session.py | 9 +++--- src/primaite/simulator/__init__.py | 30 +++++++++++++------ .../simulator/network/hardware/base.py | 2 +- .../simulator/system/core/packet_capture.py | 2 +- src/primaite/simulator/system/core/sys_log.py | 3 +- 6 files changed, 34 insertions(+), 17 deletions(-) diff --git a/src/primaite/game/io.py b/src/primaite/game/io.py index d510d108..e0b849c9 100644 --- a/src/primaite/game/io.py +++ b/src/primaite/game/io.py @@ -5,6 +5,7 @@ from typing import Optional from pydantic import BaseModel, ConfigDict from primaite import PRIMAITE_PATHS +from primaite.simulator import SIM_OUTPUT class SessionIOSettings(BaseModel): @@ -36,6 +37,10 @@ class SessionIO: def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: self.settings: SessionIOSettings = settings self.session_path: Path = self.generate_session_path() + + # set global SIM_OUTPUT path + SIM_OUTPUT.path = self.session_path / "simulation_output" + # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's # possible refactor needed diff --git a/src/primaite/game/session.py b/src/primaite/game/session.py index 655e2459..a2c04980 100644 --- a/src/primaite/game/session.py +++ b/src/primaite/game/session.py @@ -324,6 +324,11 @@ class PrimaiteSession: protocols=cfg["game_config"]["protocols"], ) sess.training_options = TrainingOptions(**cfg["training_config"]) + + # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... + io_settings = cfg.get("io_settings", {}) + sess.io_manager.settings = SessionIOSettings(**io_settings) + sim = sess.simulation net = sim.network @@ -524,8 +529,4 @@ class PrimaiteSession: if agent_load_path: sess.policy.load(Path(agent_load_path)) - # READ IO SETTINGS - io_settings = cfg.get("io_settings", {}) - sess.io_manager.settings = SessionIOSettings(**io_settings) - return sess diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 8c55542f..19c86e28 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -1,14 +1,26 @@ +"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory.""" from datetime import datetime +from pathlib import Path from primaite import _PRIMAITE_ROOT -SIM_OUTPUT = None -"A path at the repo root dir to use temporarily for sim output testing while in dev." -# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path +__all__ = ["SIM_OUTPUT"] -if not SIM_OUTPUT: - session_timestamp = datetime.now() - date_dir = session_timestamp.strftime("%Y-%m-%d") - sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path - SIM_OUTPUT.mkdir(exist_ok=True, parents=True) + +class __SimOutput: + def __init__(self): + self._path: Path = ( + _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ) + + @property + def path(self) -> Path: + return self._path + + @path.setter + def path(self, new_path: Path) -> None: + self._path = new_path + self._path.mkdir(exist_ok=True, parents=True) + + +SIM_OUTPUT = __SimOutput() diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 537cebb2..29d3a05c 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -957,7 +957,7 @@ class Node(SimComponent): if not kwargs.get("session_manager"): kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp")) if not kwargs.get("root"): - kwargs["root"] = SIM_OUTPUT / kwargs["hostname"] + kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"] if not kwargs.get("file_system"): kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs") if not kwargs.get("software_manager"): diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index 2e5ed008..c2faeb10 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -75,7 +75,7 @@ class PacketCapture: def _get_log_path(self) -> Path: """Get the path for the log file.""" - root = SIM_OUTPUT / self.hostname + root = SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self._logger_name}.log" diff --git a/src/primaite/simulator/system/core/sys_log.py b/src/primaite/simulator/system/core/sys_log.py index 791e0be8..7ac6df85 100644 --- a/src/primaite/simulator/system/core/sys_log.py +++ b/src/primaite/simulator/system/core/sys_log.py @@ -41,7 +41,6 @@ class SysLog: JSON-like messages. """ log_path = self._get_log_path() - file_handler = logging.FileHandler(filename=log_path) file_handler.setLevel(logging.DEBUG) @@ -81,7 +80,7 @@ class SysLog: :return: Path object representing the location of the log file. """ - root = SIM_OUTPUT / self.hostname + root = SIM_OUTPUT.path / self.hostname root.mkdir(exist_ok=True, parents=True) return root / f"{self.hostname}_sys.log"