Files
PrimAITE/src/primaite/session/io.py

134 lines
5.9 KiB
Python

# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import json
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
from primaite import _PRIMAITE_ROOT, getLogger, PRIMAITE_CONFIG, PRIMAITE_PATHS
from primaite.simulator import LogLevel, SIM_OUTPUT
from primaite.utils.cli.primaite_config_utils import is_dev_mode
# Logger for this module, obtained through primaite's getLogger helper.
_LOGGER = getLogger(__name__)
class PrimaiteIO:
    """
    Class for managing session IO.

    Currently it's handling path generation, but could expand to handle loading, transaction, and so on.
    """

    class Settings(BaseModel):
        """Config schema for PrimaiteIO object."""

        model_config = ConfigDict(extra="forbid")

        save_logs: bool = True
        """Whether to save logs"""
        save_agent_actions: bool = True
        """Whether to save a log of all agents' actions every step."""
        save_step_metadata: bool = False
        """Whether to save the RL agents' action, environment state, and other data at every single step."""
        save_pcap_logs: bool = True
        """Whether to save PCAP logs."""
        save_sys_logs: bool = True
        """Whether to save system logs."""
        save_agent_logs: bool = True
        """Whether to save agent logs."""
        write_sys_log_to_terminal: bool = False
        """Whether to write the sys log to the terminal."""
        write_agent_log_to_terminal: bool = False
        """Whether to write the agent log to the terminal."""
        sys_log_level: LogLevel = LogLevel.INFO
        """The level of sys logs that should be included in the logfiles/logged into terminal."""
        agent_log_level: LogLevel = LogLevel.INFO
        """The level of agent logs that should be included in the logfiles/logged into terminal."""

    def __init__(self, settings: Optional[Settings] = None) -> None:
        """
        Init the PrimaiteIO object.

        Note: Instantiating this object creates a new directory for outputs, and sets the global SIM_OUTPUT variable.
        It is intended that this object is instantiated when a new environment is created.

        :param settings: IO settings to use. When omitted, the default ``Settings()`` are used.
        :type settings: Optional[Settings]
        """
        self.settings = settings or PrimaiteIO.Settings()
        self.session_path: Path = self.generate_session_path()

        # Point the global SIM_OUTPUT at this session's folders.
        SIM_OUTPUT.path = self.session_path / "simulation_output"
        SIM_OUTPUT.agent_behaviour_path = self.session_path / "agent_behaviour"

        # Mirror the logging-related settings onto the global SIM_OUTPUT.
        SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
        SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
        SIM_OUTPUT.save_agent_logs = self.settings.save_agent_logs
        SIM_OUTPUT.write_agent_log_to_terminal = self.settings.write_agent_log_to_terminal
        SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal
        SIM_OUTPUT.sys_log_level = self.settings.sys_log_level
        SIM_OUTPUT.agent_log_level = self.settings.agent_log_level

    def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
        """
        Create a folder for the session and return the path to it.

        The folder is ``<base>/<SIM_OUTPUT.date_str>/<SIM_OUTPUT.time_str>``. The base is the user sessions path
        by default; in dev mode it is a ``sessions`` folder relative to the PrimAITE root, or under the
        configured developer ``output_dir`` when one is set.

        :param timestamp: Currently unused; the date and time components are taken from SIM_OUTPUT. Kept so that
            existing callers passing it keep working.
        :type timestamp: Optional[datetime]
        :return: Path to the (now existing) session directory.
        :rtype: Path
        """
        session_path = PRIMAITE_PATHS.user_sessions_path / SIM_OUTPUT.date_str / SIM_OUTPUT.time_str

        # check if running in dev mode
        if is_dev_mode():
            session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / SIM_OUTPUT.date_str / SIM_OUTPUT.time_str

            # check if there is an output directory set in config
            if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
                session_path = (
                    Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"])
                    / "sessions"
                    / SIM_OUTPUT.date_str
                    / SIM_OUTPUT.time_str
                )

        session_path.mkdir(exist_ok=True, parents=True)
        return session_path

    def generate_model_save_path(self, agent_name: str) -> Path:
        """Return the path where the final model will be saved (excluding filename extension)."""
        return self.session_path / "checkpoints" / f"{agent_name}_final"

    def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
        """Return the path where the checkpoint model will be saved (filename includes the ``.pt`` extension)."""
        return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"

    def generate_agent_actions_save_path(self, episode: int) -> Path:
        """Return the path where agent actions will be saved."""
        return self.session_path / "agent_actions" / f"episode_{episode}.json"

    def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None:
        """Take the contents of the agent action log and write it to a file.

        The output is a JSON object keyed by timestep index. Each entry holds the timestep, the episode and, for
        every agent whose history reaches that timestep, that agent's action at that step.

        :param agent_actions: Mapping of agent name to that agent's per-step action history.
        :type agent_actions: Dict[str, List]
        :param episode: Episode number
        :type episode: int
        """
        # default=0 so that an empty mapping yields an empty log instead of a ValueError from max().
        longest_history = max((len(hist) for hist in agent_actions.values()), default=0)

        data = {}
        for i in range(longest_history):
            step_record = {"timestep": i, "episode": episode}
            # Histories may differ in length; only include agents that have an action at step i.
            step_record.update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i})
            data[i] = step_record

        path = self.generate_agent_actions_save_path(episode=episode)
        path.parent.mkdir(exist_ok=True, parents=True)

        _LOGGER.info(f"Saving agent action log to {path}")
        # Opening in "w" mode creates the file, so no separate touch() is needed.
        with open(path, "w", encoding="utf-8") as file:
            # Objects that json cannot serialise natively are expected to expose model_dump().
            json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump())

    @classmethod
    def from_config(cls, config: Optional[Dict]) -> "PrimaiteIO":
        """Create an instance of PrimaiteIO based on a configuration dict.

        Log levels given as strings (e.g. ``"info"``) are converted to ``LogLevel`` members by upper-cased name.
        The caller's dict is not modified.

        :param config: Settings keyed by ``Settings`` field name. ``None`` or an empty dict gives the defaults.
        :type config: Optional[Dict]
        :return: A new PrimaiteIO configured from ``config``.
        :rtype: PrimaiteIO
        """
        # Work on a shallow copy so that the same config dict can be reused for several instances.
        config = dict(config) if config else {}

        for level_key in ("sys_log_level", "agent_log_level"):
            level = config.get(level_key)
            # Only non-empty strings need converting; values that are already LogLevel members pass through.
            if isinstance(level, str) and level:
                config[level_key] = LogLevel[level.upper()]  # convert to enum

        return cls(settings=cls.Settings(**config))