Merged PR 229: Step metadata json output
## Summary - Added step metadata json file dumps to the environments. Fixed serialization issues in the Switch and ACLRule classes. ## Test process Nothing outside of a manual inspection for now. ## Checklist - [ ] PR is linked to a **work item** - [ ] **acceptance criteria** of linked ticket are met - [ ] performed **self-review** of the code - [ ] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [ ] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2085
This commit is contained in:
@@ -31,4 +31,6 @@ Outputs
|
||||
|
||||
Running a session creates a session output directory in your user data folder. The filepath looks like this:
|
||||
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
|
||||
the saved agent checkpoints, and final model.
|
||||
the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that
|
||||
contains the action, reward, and simulation state. These can be found in
|
||||
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<n>/step_metadata/step_<n>.json``
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import datetime as datetime
|
||||
import logging
|
||||
import logging.config
|
||||
import shutil
|
||||
@@ -38,6 +39,7 @@ class _PrimaitePaths:
|
||||
self.app_config_file_path = self.generate_app_config_file_path()
|
||||
self.app_log_dir_path = self.generate_app_log_dir_path()
|
||||
self.app_log_file_path = self.generate_app_log_file_path()
|
||||
self.episode_log_file_path = self.generate_episode_log_file_path()
|
||||
|
||||
def _get_dirs_properties(self) -> List[str]:
|
||||
class_items = self.__class__.__dict__.items()
|
||||
@@ -105,6 +107,13 @@ class _PrimaitePaths:
|
||||
"""The PrimAITE app log file path."""
|
||||
return self.app_log_dir_path / "primaite.log"
|
||||
|
||||
def generate_episode_log_file_path(self) -> Path:
    """
    Build the path of the episode log file, creating its parent directory.

    A directory named after the current local time (``YYYY-MM-DDTHH-MM-SS``) is placed under
    ``self.app_log_dir_path``, stored on ``self.episode_log_dir_path`` and created on disk
    (including any missing parents).

    :return: The path of ``episode.log`` inside that directory.
    """
    timestamp_format = "%Y-%m-%dT%H-%M-%S"
    timestamp = datetime.datetime.now().strftime(timestamp_format)

    # The attribute is set before the directory is created, so it is available even if mkdir fails.
    self.episode_log_dir_path = self.app_log_dir_path / timestamp
    episode_dir = self.episode_log_dir_path
    episode_dir.mkdir(parents=True, exist_ok=True)

    return episode_dir / "episode.log"
|
||||
|
||||
def __repr__(self) -> str:
    """Return ``ClassName(name='value', ...)`` for every name yielded by ``_get_dirs_properties()``."""
    rendered = []
    for prop_name in self._get_dirs_properties():
        prop_value = getattr(self, prop_name)
        rendered.append(f"{prop_name}='{prop_value}'")
    joined = ", ".join(rendered)
    return f"{self.__class__.__name__}({joined})"
|
||||
|
||||
@@ -13,6 +13,7 @@ training_config:
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
|
||||
|
||||
game:
|
||||
|
||||
@@ -9,6 +9,7 @@ training_config:
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_step_metadata: false
|
||||
|
||||
|
||||
game:
|
||||
|
||||
@@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -84,6 +85,9 @@ class PrimaiteGame:
|
||||
self.ref_map_links: Dict[str, str] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
|
||||
self.save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
Perform one step of the simulation/agent loop.
|
||||
@@ -180,8 +184,13 @@ class PrimaiteGame:
|
||||
:return: A PrimaiteGame object.
|
||||
:rtype: PrimaiteGame
|
||||
"""
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
_ = SessionIO(SessionIOSettings(**io_settings))
|
||||
# Instantiating this ensures that the game saves to the correct output dir even without being part of a session
|
||||
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(**cfg["game"])
|
||||
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
|
||||
|
||||
# 1. create simulation
|
||||
sim = game.simulation
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
|
||||
|
||||
import gymnasium
|
||||
@@ -6,6 +7,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
@@ -30,6 +32,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.game.apply_agent_actions()
|
||||
self.game.advance_timestep()
|
||||
state = self.game.get_sim_state()
|
||||
|
||||
self.game.update_agents(state)
|
||||
|
||||
next_obs = self._get_obs()
|
||||
@@ -37,8 +40,26 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
terminated = False
|
||||
truncated = self.game.calculate_truncated()
|
||||
info = {}
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: SupportsFloat):
    """
    Dump the metadata of the current step to a json file.

    The file is written to ``<SIM_OUTPUT.path>/episode_<episode>/step_metadata/step_<step>.json``,
    using the game's current episode and step counters. The directory is created if it does not
    exist yet.

    :param action: The action passed to this step; stored as an int.
    :param state: The simulation state for this step. It must be json-serialisable.
    :param reward: The reward for this step; stored as a float.
    """
    output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata"
    output_dir.mkdir(parents=True, exist_ok=True)
    path = output_dir / f"step_{self.game.step_counter}.json"

    data = {
        "episode": self.game.episode_counter,
        "step": self.game.step_counter,
        "action": int(action),
        # float(), not int(): int() would truncate a fractional reward (e.g. 0.8 -> 0).
        "reward": float(reward),
        "state": state,
    }
    with open(path, "w") as file:
        json.dump(data, file)
|
||||
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
print(
|
||||
@@ -162,8 +183,26 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
infos = {}
|
||||
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
    """
    Dump the metadata of the current multi-agent step to a json file.

    The file is written to ``<SIM_OUTPUT.path>/episode_<episode>/step_metadata/step_<step>.json``,
    using the game's current episode and step counters. The directory is created if it does not
    exist yet.

    :param actions: Mapping of agent name to the action it took; each action is stored as an int.
    :param state: The simulation state for this step. It must be json-serialisable.
    :param rewards: The rewards for this step, written unchanged under the ``reward`` key.
    """
    game = self.game
    episode_dir = SIM_OUTPUT.path / f"episode_{game.episode_counter}"
    metadata_dir = episode_dir / "step_metadata"
    metadata_dir.mkdir(parents=True, exist_ok=True)
    target = metadata_dir / f"step_{game.step_counter}.json"

    payload = {"episode": game.episode_counter, "step": game.step_counter}

    # Convert each action to a plain int so it can be json-encoded.
    serialised_actions = {}
    for agent_name, action in actions.items():
        serialised_actions[agent_name] = int(action)
    payload["actions"] = serialised_actions

    payload["reward"] = rewards
    payload["state"] = state

    with open(target, "w") as handle:
        json.dump(payload, handle)
|
||||
|
||||
def _get_obs(self) -> Dict[str, ObsType]:
|
||||
"""Return the current observation."""
|
||||
obs = {}
|
||||
|
||||
@@ -25,6 +25,8 @@ class SessionIOSettings(BaseModel):
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
|
||||
|
||||
class SessionIO:
|
||||
|
||||
@@ -66,9 +66,9 @@ class ACLRule(SimComponent):
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ class Switch(Node):
|
||||
state = super().describe_state()
|
||||
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
|
||||
state["num_ports"] = self.num_ports # redundant?
|
||||
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
|
||||
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
|
||||
return state
|
||||
|
||||
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
|
||||
|
||||
Reference in New Issue
Block a user