Merged PR 229: Step metadata json output

## Summary
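- Added per-step metadata JSON file dumps to the environments. Fixed JSON serialization issues in the `Switch` and `ACLRule` classes (IP addresses are now emitted as strings and MAC address table entries as port numbers).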

## Test process
No automated tests yet; the output has only been verified by manual inspection.

## Checklist
- [ ] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [ ] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [ ] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2085
This commit is contained in:
Christopher McCarthy
2023-12-04 10:47:09 +00:00
committed by Marek Wolan
9 changed files with 67 additions and 4 deletions

View File

@@ -31,4 +31,6 @@ Outputs
Running a session creates a session output directory in your user data folder. The filepath looks like this:
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
the saved agent checkpoints, and final model.
the saved agent checkpoints, and final model. If ``save_step_metadata`` is enabled in ``io_settings``, the folder also
contains a .json file for each episode step that contains the action, reward, and simulation state. These can be found in
``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<episode>/step_metadata/step_<step>.json``

View File

@@ -1,4 +1,5 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import datetime as datetime
import logging
import logging.config
import shutil
@@ -38,6 +39,7 @@ class _PrimaitePaths:
self.app_config_file_path = self.generate_app_config_file_path()
self.app_log_dir_path = self.generate_app_log_dir_path()
self.app_log_file_path = self.generate_app_log_file_path()
self.episode_log_file_path = self.generate_episode_log_file_path()
def _get_dirs_properties(self) -> List[str]:
class_items = self.__class__.__dict__.items()
@@ -105,6 +107,13 @@ class _PrimaitePaths:
"""The PrimAITE app log file path."""
return self.app_log_dir_path / "primaite.log"
def generate_episode_log_file_path(self) -> Path:
    """
    Create a timestamped episode log directory and return the log file path inside it.

    The directory is named after the current local time (``YYYY-MM-DDTHH-MM-SS``) and lives under
    ``self.app_log_dir_path``. As a side effect, ``self.episode_log_dir_path`` is set to that directory.

    :return: Path to ``episode.log`` inside the newly created directory.
    """
    timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    episode_dir = self.app_log_dir_path / timestamp
    # Publish the attribute before touching the filesystem, so it is set even if mkdir raises.
    self.episode_log_dir_path = episode_dir
    episode_dir.mkdir(parents=True, exist_ok=True)
    return episode_dir / "episode.log"
def __repr__(self) -> str:
properties_str = ", ".join([f"{p}='{getattr(self, p)}'" for p in self._get_dirs_properties()])
return f"{self.__class__.__name__}({properties_str})"

View File

@@ -13,6 +13,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
game:

View File

@@ -9,6 +9,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_step_metadata: false
game:

View File

@@ -10,6 +10,7 @@ from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -84,6 +85,9 @@ class PrimaiteGame:
self.ref_map_links: Dict[str, str] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
def step(self):
"""
Perform one step of the simulation/agent loop.
@@ -180,8 +184,13 @@ class PrimaiteGame:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
io_settings = cfg.get("io_settings", {})
_ = SessionIO(SessionIOSettings(**io_settings))
# Instantiating this ensures that the game saves to the correct output dir even without being part of a session
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
# 1. create simulation
sim = game.simulation

View File

@@ -1,3 +1,4 @@
import json
from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple
import gymnasium
@@ -6,6 +7,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator import SIM_OUTPUT
class PrimaiteGymEnv(gymnasium.Env):
@@ -30,6 +32,7 @@ class PrimaiteGymEnv(gymnasium.Env):
self.game.apply_agent_actions()
self.game.advance_timestep()
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
@@ -37,8 +40,26 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {}
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: SupportsFloat):
    """
    Dump the action, reward and simulation state of the current step to a JSON file.

    The file is written to ``<SIM_OUTPUT.path>/episode_<episode>/step_metadata/step_<step>.json``, where the
    episode and step numbers come from ``self.game``. The directory is created if it does not exist.

    :param action: Action taken this step; coerced to a plain ``int`` before serialisation.
    :param state: Simulation state dictionary; it is passed to ``json.dump`` unchanged.
    :param reward: Reward for this step; coerced to a plain ``float`` before serialisation.
    """
    episode = self.game.episode_counter
    step = self.game.step_counter

    output_dir = SIM_OUTPUT.path / f"episode_{episode}" / "step_metadata"
    output_dir.mkdir(parents=True, exist_ok=True)
    path = output_dir / f"step_{step}.json"

    data = {
        "episode": episode,
        "step": step,
        "action": int(action),
        # float() rather than int(): int() would truncate fractional rewards (e.g. 0.5 -> 0).
        "reward": float(reward),
        "state": state,
    }
    with open(path, "w", encoding="utf-8") as file:
        json.dump(data, file)
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
print(
@@ -162,8 +183,26 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
infos = {}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
    """
    Dump every agent's action, the rewards and the simulation state of the current step to a JSON file.

    The file is written to ``<SIM_OUTPUT.path>/episode_<episode>/step_metadata/step_<step>.json``, where the
    episode and step numbers come from ``self.game``. The directory is created if it does not exist.

    :param actions: Mapping of agent name to action; each action is coerced to a plain ``int``.
    :param state: Simulation state dictionary; it is passed to ``json.dump`` unchanged.
    :param rewards: Rewards mapping; it is stored as-is under the ``"reward"`` key.
    """
    episode_number = self.game.episode_counter
    step_number = self.game.step_counter

    metadata_dir = SIM_OUTPUT.path / f"episode_{episode_number}" / "step_metadata"
    metadata_dir.mkdir(parents=True, exist_ok=True)

    serialisable_actions = {}
    for agent_name, action in actions.items():
        serialisable_actions[agent_name] = int(action)

    payload = {
        "episode": episode_number,
        "step": step_number,
        "actions": serialisable_actions,
        "reward": rewards,
        "state": state,
    }
    with open(metadata_dir / f"step_{step_number}.json", "w") as out_file:
        json.dump(payload, out_file)
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}

View File

@@ -25,6 +25,8 @@ class SessionIOSettings(BaseModel):
"""Whether to save transactions, If true, the session path will have a transactions folder."""
save_tensorboard_logs: bool = False
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
class SessionIO:

View File

@@ -66,9 +66,9 @@ class ACLRule(SimComponent):
state = super().describe_state()
state["action"] = self.action.value
state["protocol"] = self.protocol.value if self.protocol else None
state["src_ip_address"] = self.src_ip_address if self.src_ip_address else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_port"] = self.src_port.value if self.src_port else None
state["dst_ip_address"] = self.dst_ip_address if self.dst_ip_address else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_port"] = self.dst_port.value if self.dst_port else None
return state

View File

@@ -57,7 +57,7 @@ class Switch(Node):
state = super().describe_state()
state["ports"] = {port_num: port.describe_state() for port_num, port in self.switch_ports.items()}
state["num_ports"] = self.num_ports # redundant?
state["mac_address_table"] = {mac: port for mac, port in self.mac_address_table.items()}
state["mac_address_table"] = {mac: port.port_num for mac, port in self.mac_address_table.items()}
return state
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):