Merged PR 295: Agent action logging
## Summary Added a new optional capability to create a JSON log each episode with a list of the actions each agent took at each step (including scripted, RL, and red agents). Also I had to slightly refactor the IO system to not rely on PrimaiteSession, as it is going to be deprecated soon. Therefore the IO module is now linked to the gym environment. Each time you init a gym environment, it creates a session directory. ## Test process Tried the SB3, Ray SARL and Ray MARL notebooks to see that the outputs get generated. ## Checklist - [x] PR is linked to a **work item** - [x] **acceptance criteria** of linked ticket are met - [x] performed **self-review** of the code - [ ] written **tests** for any new functionality added with this PR - [x] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [x] updated the **change log** - [x] ran **pre-commit** checks for code style - [x] attended to any **TO-DOs** left in the code Move IO to environments from session and add agent logging Related work items: #2278
This commit is contained in:
@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Fixed the order of service health state
|
||||
- Fixed an issue where starting a node didn't start the services on it
|
||||
- Added support for SQL INSERT command.
|
||||
- Added ability to log each agent's action choices in each step to a JSON file.
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ This section configures how PrimAITE saves data during simulation and training.
|
||||
checkpoint_interval: 10
|
||||
# save_logs: True
|
||||
# save_transactions: False
|
||||
# save_tensorboard_logs: False
|
||||
save_agent_actions: True
|
||||
save_step_metadata: False
|
||||
save_pcap_logs: False
|
||||
save_sys_logs: False
|
||||
@@ -55,15 +55,13 @@ Defines how often to save the policy during training.
|
||||
|
||||
*currently unused*.
|
||||
|
||||
``save_transactions``
|
||||
---------------------
|
||||
|
||||
*currently unused*.
|
||||
``save_agent_actions``
|
||||
----------------------
|
||||
|
||||
``save_tensorboard_logs``
|
||||
-------------------------
|
||||
Optional. Default value is ``True``.
|
||||
|
||||
*currently unused*.
|
||||
If ``True``, this will create a JSON file each episode detailing every agent's action in each step of that episode, formatted according to the CAOS format. This includes scripted, RL, and red agents.
|
||||
|
||||
``save_step_metadata``
|
||||
----------------------
|
||||
|
||||
@@ -13,6 +13,7 @@ training_config:
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: true
|
||||
|
||||
@@ -15,6 +15,7 @@ training_config:
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: true
|
||||
@@ -35,7 +36,7 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.3
|
||||
@@ -1283,7 +1284,6 @@ agents:
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nmne_config:
|
||||
|
||||
@@ -37,7 +37,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
if timestep < self.next_execution_timestep:
|
||||
return "DONOTHING", {"dummy": 0}
|
||||
return "DONOTHING", {}
|
||||
|
||||
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAge
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents import ProbabilisticAgent
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
@@ -210,10 +209,6 @@ class PrimaiteGame:
|
||||
:return: A PrimaiteGame object.
|
||||
:rtype: PrimaiteGame
|
||||
"""
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
_ = SessionIO(SessionIOSettings(**io_settings))
|
||||
# Instantiating this ensures that the game saves to the correct output dir even without being part of a session
|
||||
|
||||
game = cls()
|
||||
game.options = PrimaiteGameOptions(**cfg["game"])
|
||||
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
|
||||
@@ -415,7 +410,7 @@ class PrimaiteGame:
|
||||
# CREATE AGENT
|
||||
if agent_type == "ProbabilisticAgent":
|
||||
# TODO: implement non-random agents and fix this parsing
|
||||
settings = agent_cfg.get("agent_settings")
|
||||
settings = agent_cfg.get("agent_settings", {})
|
||||
new_agent = ProbabilisticAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
|
||||
@@ -9,6 +9,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -35,6 +36,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
@property
|
||||
def agent(self) -> ProxyAgent:
|
||||
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
|
||||
@@ -58,6 +62,10 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(action, state, reward)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
|
||||
@@ -82,6 +90,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
f"Resetting environment, episode {self.episode_counter}, "
|
||||
f"avg. reward: {self.agent.reward_function.total_reward}"
|
||||
)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
@@ -149,7 +160,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""
|
||||
self.game_config: Dict = env_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
@@ -167,6 +178,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
self.action_space = gymnasium.spaces.Dict(
|
||||
{name: agent.action_manager.space for name, agent in self.agents.items()}
|
||||
)
|
||||
|
||||
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
super().__init__()
|
||||
|
||||
@property
|
||||
@@ -176,7 +191,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
|
||||
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.write_agent_actions(episode=self.episode_counter)
|
||||
self.io.clear_agent_actions()
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
state = self.game.get_sim_state()
|
||||
@@ -199,7 +217,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
# 1. Perform actions
|
||||
for agent_name, action in actions.items():
|
||||
self.agents[agent_name].store_action(action)
|
||||
self.game.apply_agent_actions()
|
||||
agent_actions = self.game.apply_agent_actions()
|
||||
|
||||
# 2. Advance timestep
|
||||
self.game.advance_timestep()
|
||||
@@ -218,6 +236,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
truncateds["__all__"] = self.game.calculate_truncated()
|
||||
if self.game.save_step_metadata:
|
||||
self._write_step_metadata_json(actions, state, rewards)
|
||||
if self.io.settings.save_agent_actions:
|
||||
self.io.store_agent_actions(
|
||||
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
|
||||
)
|
||||
return next_obs, rewards, terminateds, truncateds, infos
|
||||
|
||||
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):
|
||||
|
||||
@@ -1,55 +1,54 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class SessionIOSettings(BaseModel):
|
||||
"""Schema for session IO settings."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
save_final_model: bool = True
|
||||
"""Whether to save the final model right at the end of training."""
|
||||
save_checkpoints: bool = False
|
||||
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
|
||||
checkpoint_interval: int = 10
|
||||
"""How often to save a checkpoint model (if save_checkpoints is True)."""
|
||||
save_logs: bool = True
|
||||
"""Whether to save logs"""
|
||||
save_transactions: bool = True
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
save_pcap_logs: bool = False
|
||||
"""Whether to save PCAP logs."""
|
||||
save_sys_logs: bool = False
|
||||
"""Whether to save system logs."""
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class SessionIO:
|
||||
class PrimaiteIO:
|
||||
"""
|
||||
Class for managing session IO.
|
||||
|
||||
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
|
||||
Currently it's handling path generation, but could expand to handle loading, transaction, and so on.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings: SessionIOSettings = settings
|
||||
class Settings(BaseModel):
|
||||
"""Config schema for PrimaiteIO object."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
save_logs: bool = True
|
||||
"""Whether to save logs"""
|
||||
save_agent_actions: bool = True
|
||||
"""Whether to save a log of all agents' actions every step."""
|
||||
save_step_metadata: bool = False
|
||||
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
|
||||
save_pcap_logs: bool = False
|
||||
"""Whether to save PCAP logs."""
|
||||
save_sys_logs: bool = False
|
||||
"""Whether to save system logs."""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None) -> None:
|
||||
"""
|
||||
Init the PrimaiteIO object.
|
||||
|
||||
Note: Instantiating this object creates a new directory for outputs, and sets the global SIM_OUTPUT variable.
|
||||
It is intended that this object is instantiated when a new environment is created.
|
||||
"""
|
||||
self.settings = settings or PrimaiteIO.Settings()
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
# set global SIM_OUTPUT path
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
|
||||
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
|
||||
# possible refactor needed
|
||||
self.agent_action_log: List[Dict] = []
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
@@ -68,3 +67,56 @@ class SessionIO:
|
||||
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
|
||||
"""Return the path where the checkpoint model will be saved (excluding filename extension)."""
|
||||
return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"
|
||||
|
||||
def generate_agent_actions_save_path(self, episode: int) -> Path:
|
||||
"""Return the path where agent actions will be saved."""
|
||||
return self.session_path / "agent_actions" / f"episode_{episode}.json"
|
||||
|
||||
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
|
||||
"""Cache agent actions for a particular step.
|
||||
|
||||
:param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
|
||||
format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters]
|
||||
CAOS action is a string representing one the CAOS actions.
|
||||
parameters is a dict of parameter names and values for that particular CAOS action.
|
||||
For example:
|
||||
{
|
||||
'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
|
||||
'defender': ('DO_NOTHING', {})
|
||||
}
|
||||
:type agent_actions: Dict
|
||||
:param timestep: Simulation timestep when these actions occurred.
|
||||
:type timestep: int
|
||||
"""
|
||||
self.agent_action_log.append(
|
||||
[
|
||||
{
|
||||
"episode": episode,
|
||||
"timestep": timestep,
|
||||
"agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()},
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def write_agent_actions(self, episode: int) -> None:
|
||||
"""Take the contents of the agent action log and write it to a file.
|
||||
|
||||
:param episode: Episode number
|
||||
:type episode: int
|
||||
"""
|
||||
path = self.generate_agent_actions_save_path(episode=episode)
|
||||
path.parent.mkdir(exist_ok=True, parents=True)
|
||||
path.touch()
|
||||
_LOGGER.info(f"Saving agent action log to {path}")
|
||||
with open(path, "w") as file:
|
||||
json.dump(self.agent_action_log, fp=file, indent=1)
|
||||
|
||||
def clear_agent_actions(self) -> None:
|
||||
"""Reset the agent action log back to an empty dictionary."""
|
||||
self.agent_action_log = []
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
"""Create an instance of PrimaiteIO based on a configuration dict."""
|
||||
new = cls()
|
||||
return new
|
||||
|
||||
@@ -39,9 +39,9 @@ class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
if self.session.io_manager.settings.save_checkpoints:
|
||||
if self.session.save_checkpoints:
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
|
||||
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
|
||||
save_path=self.session.io_manager.generate_model_save_path("sb3"),
|
||||
name_prefix="sb3_model",
|
||||
)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# raise DeprecationWarning("This module is deprecated")
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
@@ -5,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.io import SessionIO, SessionIOSettings
|
||||
from primaite.session.io import PrimaiteIO
|
||||
|
||||
# from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
@@ -53,12 +54,18 @@ class PrimaiteSession:
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.io_manager: Optional["SessionIO"] = None
|
||||
self.io_manager: Optional["PrimaiteIO"] = None
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game_cfg: Dict = game_cfg
|
||||
"""Primaite Game object for managing main simulation loop and agents."""
|
||||
|
||||
self.save_checkpoints: bool = False
|
||||
"""Whether to save checkpoints."""
|
||||
|
||||
self.checkpoint_interval: int = 10
|
||||
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training/eval session."""
|
||||
print("Starting Primaite Session")
|
||||
@@ -89,12 +96,13 @@ class PrimaiteSession:
|
||||
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary."""
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
io_manager = SessionIO(SessionIOSettings(**io_settings))
|
||||
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
|
||||
|
||||
sess = cls(game_cfg=cfg)
|
||||
sess.io_manager = io_manager
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
|
||||
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
|
||||
Reference in New Issue
Block a user