Merged PR 295: Agent action logging

## Summary
Added a new optional capability to create a JSON log for each episode, listing the action each agent took at each step (including scripted, RL, and red agents).

I also had to slightly refactor the IO system so that it no longer relies on PrimaiteSession, which is going to be deprecated soon. The IO module is therefore now linked to the gym environment: each time you initialise a gym environment, it creates a session directory.

## Test process
Tried the SB3, Ray SARL and Ray MARL notebooks to see that the outputs get generated.

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [x] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [x] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [x] attended to any **TO-DOs** left in the code

Move IO to environments from session and add agent logging

Related work items: #2278
This commit is contained in:
Marek Wolan
2024-03-05 16:39:44 +00:00
10 changed files with 134 additions and 57 deletions

View File

@@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed the order of service health state
- Fixed an issue where starting a node didn't start the services on it
- Added support for SQL INSERT command.
- Added ability to log each agent's action choices in each step to a JSON file.

View File

@@ -18,7 +18,7 @@ This section configures how PrimAITE saves data during simulation and training.
checkpoint_interval: 10
# save_logs: True
# save_transactions: False
# save_tensorboard_logs: False
save_agent_actions: True
save_step_metadata: False
save_pcap_logs: False
save_sys_logs: False
@@ -55,15 +55,13 @@ Defines how often to save the policy during training.
*currently unused*.
``save_transactions``
---------------------
*currently unused*.
``save_agent_actions``
----------------------
``save_tensorboard_logs``
-------------------------
Optional. Default value is ``True``.
*currently unused*.
If ``True``, this will create a JSON file each episode detailing every agent's action in each step of that episode, formatted according to the CAOS format. This includes scripted, RL, and red agents.
``save_step_metadata``
----------------------

View File

@@ -13,6 +13,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true

View File

@@ -15,6 +15,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true
@@ -35,7 +36,7 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
@@ -1283,7 +1284,6 @@ agents:
simulation:
network:
nmne_config:

View File

@@ -37,7 +37,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
return "DONOTHING", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)

View File

@@ -11,7 +11,6 @@ from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAge
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import ProbabilisticAgent
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -210,10 +209,6 @@ class PrimaiteGame:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
io_settings = cfg.get("io_settings", {})
_ = SessionIO(SessionIOSettings(**io_settings))
# Instantiating this ensures that the game saves to the correct output dir even without being part of a session
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
@@ -415,7 +410,7 @@ class PrimaiteGame:
# CREATE AGENT
if agent_type == "ProbabilisticAgent":
# TODO: implement non-random agents and fix this parsing
settings = agent_cfg.get("agent_settings")
settings = agent_cfg.get("agent_settings", {})
new_agent = ProbabilisticAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,

View File

@@ -9,6 +9,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite import getLogger
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
_LOGGER = getLogger(__name__)
@@ -35,6 +36,9 @@ class PrimaiteGymEnv(gymnasium.Env):
self.episode_counter: int = 0
"""Current episode number."""
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
@@ -58,6 +62,10 @@ class PrimaiteGymEnv(gymnasium.Env):
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
@@ -82,6 +90,9 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
@@ -149,7 +160,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
"""
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
@@ -167,6 +178,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
super().__init__()
@property
@@ -176,7 +191,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
@@ -199,7 +217,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.apply_agent_actions()
agent_actions = self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -218,6 +236,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):

View File

@@ -1,55 +1,54 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
from primaite import PRIMAITE_PATHS
from primaite import getLogger, PRIMAITE_PATHS
from primaite.simulator import SIM_OUTPUT
class SessionIOSettings(BaseModel):
"""Schema for session IO settings."""
model_config = ConfigDict(extra="forbid")
save_final_model: bool = True
"""Whether to save the final model right at the end of training."""
save_checkpoints: bool = False
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
checkpoint_interval: int = 10
"""How often to save a checkpoint model (if save_checkpoints is True)."""
save_logs: bool = True
"""Whether to save logs"""
save_transactions: bool = True
"""Whether to save transactions, If true, the session path will have a transactions folder."""
save_tensorboard_logs: bool = False
"""Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
save_pcap_logs: bool = False
"""Whether to save PCAP logs."""
save_sys_logs: bool = False
"""Whether to save system logs."""
_LOGGER = getLogger(__name__)
class SessionIO:
class PrimaiteIO:
"""
Class for managing session IO.
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
Currently it's handling path generation, but could expand to handle loading, transaction, and so on.
"""
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
self.settings: SessionIOSettings = settings
class Settings(BaseModel):
    """Config schema for PrimaiteIO object.

    Every field is a boolean switch for one kind of output. A field that is not supplied takes the default
    shown on its declaration.
    """

    # extra="forbid" makes validation fail on option names that are not declared below.
    model_config = ConfigDict(extra="forbid")
    save_logs: bool = True
    """Whether to save logs"""
    save_agent_actions: bool = True
    """Whether to save a log of all agents' actions every step."""
    save_step_metadata: bool = False
    """Whether to save the RL agents' action, environment state, and other data at every single step."""
    save_pcap_logs: bool = False
    """Whether to save PCAP logs."""
    save_sys_logs: bool = False
    """Whether to save system logs."""
def __init__(self, settings: Optional[Settings] = None) -> None:
    """
    Init the PrimaiteIO object.

    Note: Instantiating this object creates a new directory for outputs, and sets the global SIM_OUTPUT variable.
    It is intended that this object is instantiated when a new environment is created.

    :param settings: IO options to use. When omitted, a ``PrimaiteIO.Settings`` with default values is used.
    :type settings: Optional[Settings]
    """
    self.settings = settings or PrimaiteIO.Settings()
    # The session path is generated first because SIM_OUTPUT.path below is derived from it.
    self.session_path: Path = self.generate_session_path()
    # set global SIM_OUTPUT path
    SIM_OUTPUT.path = self.session_path / "simulation_output"
    SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
    SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
    # WARNING / TODO: be careful not to re-initialise this object, because every instantiation generates a new
    # session path. A refactor may be needed.
    # Per-step action records; appended to by store_agent_actions and reset by clear_agent_actions.
    self.agent_action_log: List[Dict] = []
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
@@ -68,3 +67,56 @@ class SessionIO:
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
    """Build the checkpoint path for an agent at a given episode.

    The result is ``<session_path>/checkpoints/<agent_name>_checkpoint_<episode>.pt``.

    :param agent_name: Name of the agent, used as the file name prefix.
    :type agent_name: str
    :param episode: Episode number, used as the file name suffix.
    :type episode: int
    :return: Path of the checkpoint file inside the session directory.
    :rtype: Path
    """
    checkpoint_dir = self.session_path / "checkpoints"
    file_name = f"{agent_name}_checkpoint_{episode}.pt"
    return checkpoint_dir / file_name
def generate_agent_actions_save_path(self, episode: int) -> Path:
    """Build the path of the JSON file that holds one episode's agent actions.

    The result is ``<session_path>/agent_actions/episode_<episode>.json``.

    :param episode: Episode number, used in the file name.
    :type episode: int
    :return: Path of the agent action log for that episode.
    :rtype: Path
    """
    actions_dir = self.session_path / "agent_actions"
    file_name = f"episode_{episode}.json"
    return actions_dir / file_name
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
    """Cache the actions that agents took during a single step.

    :param agent_actions: Dictionary describing actions for any agents that acted in this timestep. Keys are
        agent identifiers; each value is a pair of (CAOS action, parameters), where the CAOS action is a string
        naming one of the CAOS actions and parameters is a dict of parameter names and values for that action.
        For example:
        {
            'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
            'defender': ('DO_NOTHING', {})
        }
    :type agent_actions: Dict
    :param episode: Episode number recorded alongside the actions.
    :type episode: int
    :param timestep: Simulation timestep when these actions occurred.
    :type timestep: int
    """
    actions_by_agent = {}
    for agent_name, action_pair in agent_actions.items():
        actions_by_agent[agent_name] = {"action": action_pair[0], "parameters": action_pair[1]}

    step_record = {
        "episode": episode,
        "timestep": timestep,
        "agent_actions": actions_by_agent,
    }
    # NOTE(review): each record is appended wrapped in a one-element list, so the log is a list of lists.
    # This is kept as-is to preserve the written JSON layout; confirm whether the wrapping is intended.
    self.agent_action_log.append([step_record])
def write_agent_actions(self, episode: int) -> None:
    """Dump the cached agent action log to the JSON file for an episode.

    The parent directory is created if it does not exist, and any existing file is overwritten.

    :param episode: Episode number, used to choose the output file.
    :type episode: int
    """
    destination = self.generate_agent_actions_save_path(episode=episode)
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.touch()
    _LOGGER.info(f"Saving agent action log to {destination}")
    with destination.open("w") as log_file:
        json.dump(self.agent_action_log, log_file, indent=1)
def clear_agent_actions(self) -> None:
    """Reset the agent action log back to an empty list."""
    # Rebinds the attribute to a new list rather than emptying the existing one in place.
    self.agent_action_log = []
@classmethod
def from_config(cls, config: Optional[Dict] = None) -> "PrimaiteIO":
    """Create an instance of PrimaiteIO based on a configuration dict.

    :param config: The ``io_settings`` section of a config. ``None`` or an empty dict yields default settings.
        Keys that are not fields of ``PrimaiteIO.Settings`` are ignored here rather than rejected.
    :type config: Optional[Dict]
    :return: A new PrimaiteIO object configured from the recognised options.
    :rtype: PrimaiteIO
    """
    # Callers may pass the result of ``dict.get("io_settings")``, which is None when the section is absent.
    config = config or {}
    # Settings forbids extra keys, so pass through only the options it declares.
    recognised = cls.Settings.model_fields
    settings_kwargs = {key: value for key, value in config.items() if key in recognised}
    return cls(settings=cls.Settings(**settings_kwargs))

View File

@@ -39,9 +39,9 @@ class SB3Policy(PolicyABC, identifier="SB3"):
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
if self.session.io_manager.settings.save_checkpoints:
if self.session.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)

View File

@@ -1,3 +1,4 @@
# raise DeprecationWarning("This module is deprecated")
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
@@ -5,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.session.io import PrimaiteIO
# from primaite.game.game import PrimaiteGame
from primaite.session.policy.policy import PolicyABC
@@ -53,12 +54,18 @@ class PrimaiteSession:
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.io_manager: Optional["SessionIO"] = None
self.io_manager: Optional["PrimaiteIO"] = None
"""IO manager for the session."""
self.game_cfg: Dict = game_cfg
"""Primaite Game object for managing main simulation loop and agents."""
self.save_checkpoints: bool = False
"""Whether to save checkpoints."""
self.checkpoint_interval: int = 10
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
def start_session(self) -> None:
"""Commence the training/eval session."""
print("Starting Primaite Session")
@@ -89,12 +96,13 @@ class PrimaiteSession:
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
io_manager = SessionIO(SessionIOSettings(**io_settings))
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
sess = cls(game_cfg=cfg)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":