Move IO to environments from session and add agent logging

This commit is contained in:
Marek Wolan
2024-03-04 18:47:50 +00:00
parent c32bd3f941
commit 2f456e7ae0
8 changed files with 183 additions and 64 deletions

View File

@@ -13,6 +13,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true

View File

@@ -15,6 +15,7 @@ training_config:
io_settings:
save_checkpoints: true
checkpoint_interval: 5
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: true
@@ -36,6 +37,11 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
action_space:
@@ -47,24 +53,38 @@ agents:
- node_name: client_2
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_green_user
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
action_space:
@@ -76,10 +96,26 @@ agents:
- node_name: client_1
applications:
- application_name: WebBrowser
- application_name: DatabaseClient
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_applications_per_node: 1
max_applications_per_node: 2
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
2:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 1
reward_function:
reward_components:
- type: DUMMY
@@ -1036,7 +1072,6 @@ agents:
simulation:
network:
nmne_config:
@@ -1046,8 +1081,8 @@ simulation:
nodes:
- ref: router_1
type: router
hostname: router_1
type: router
num_ports: 5
ports:
1:
@@ -1082,18 +1117,18 @@ simulation:
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
type: switch
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
type: switch
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
type: server
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
@@ -1105,8 +1140,8 @@ simulation:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
type: server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
@@ -1122,8 +1157,8 @@ simulation:
- ref: database_server
type: server
hostname: database_server
type: server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
@@ -1137,8 +1172,8 @@ simulation:
type: FTPClient
- ref: backup_server
type: server
hostname: backup_server
type: server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
@@ -1148,8 +1183,8 @@ simulation:
type: FTPServer
- ref: security_suite
type: server
hostname: security_suite
type: server
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
@@ -1160,8 +1195,8 @@ simulation:
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
type: computer
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
@@ -1178,13 +1213,17 @@ simulation:
type: WebBrowser
options:
target_url: http://arcd.com/users/
- ref: client_1_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
type: computer
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
@@ -1201,6 +1240,10 @@ simulation:
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
- ref: client_2_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
services:
- ref: client_2_dns_client
type: DNSClient

View File

@@ -37,7 +37,7 @@ class DataManipulationAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
return "DONOTHING", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)

View File

@@ -11,7 +11,6 @@ from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAge
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import ProbabilisticAgent
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -210,10 +209,6 @@ class PrimaiteGame:
:return: A PrimaiteGame object.
:rtype: PrimaiteGame
"""
io_settings = cfg.get("io_settings", {})
_ = SessionIO(SessionIOSettings(**io_settings))
# Instantiating this ensures that the game saves to the correct output dir even without being part of a session
game = cls()
game.options = PrimaiteGameOptions(**cfg["game"])
game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False
@@ -415,7 +410,7 @@ class PrimaiteGame:
# CREATE AGENT
if agent_type == "ProbabilisticAgent":
# TODO: implement non-random agents and fix this parsing
settings = agent_cfg.get("agent_settings")
settings = agent_cfg.get("agent_settings", {})
new_agent = ProbabilisticAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,

View File

@@ -8,6 +8,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
@@ -32,6 +33,9 @@ class PrimaiteGymEnv(gymnasium.Env):
self.episode_counter: int = 0
"""Current episode number."""
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""
@@ -55,6 +59,10 @@ class PrimaiteGymEnv(gymnasium.Env):
info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(action, state, reward)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, reward, terminated, truncated, info
def _write_step_metadata_json(self, action: int, state: Dict, reward: int):
@@ -79,6 +87,9 @@ class PrimaiteGymEnv(gymnasium.Env):
f"Resetting environment, episode {self.episode_counter}, "
f"avg. reward: {self.agent.reward_function.total_reward}"
)
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
@@ -146,7 +157,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
"""
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config)
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
@@ -164,6 +175,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
super().__init__()
@property
@@ -173,7 +188,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config)
if self.io.settings.save_agent_actions:
self.io.write_agent_actions(episode=self.episode_counter)
self.io.clear_agent_actions()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
state = self.game.get_sim_state()
@@ -196,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.apply_agent_actions()
agent_actions = self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
@@ -215,6 +233,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(actions, state, rewards)
if self.io.settings.save_agent_actions:
self.io.store_agent_actions(
agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter
)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict):

View File

@@ -1,53 +1,50 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict
from primaite import PRIMAITE_PATHS
from primaite import getLogger, PRIMAITE_PATHS
from primaite.simulator import SIM_OUTPUT
class SessionIOSettings(BaseModel):
"""Schema for session IO settings."""
model_config = ConfigDict(extra="forbid")
save_final_model: bool = True
"""Whether to save the final model right at the end of training."""
save_checkpoints: bool = False
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
checkpoint_interval: int = 10
"""How often to save a checkpoint model (if save_checkpoints is True)."""
save_logs: bool = True
"""Whether to save logs"""
save_transactions: bool = True
"""Whether to save transactions. If true, the session path will have a transactions folder."""
save_tensorboard_logs: bool = False
"""Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
save_pcap_logs: bool = False
"""Whether to save PCAP logs."""
save_sys_logs: bool = False
"""Whether to save system logs."""
_LOGGER = getLogger(__name__)
class SessionIO:
class PrimaiteIO:
"""
Class for managing session IO.
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
"""
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
self.settings: SessionIOSettings = settings
class Settings(BaseModel):
"""Config schema for PrimaiteIO object."""
model_config = ConfigDict(extra="forbid")
save_logs: bool = True
"""Whether to save logs"""
save_agent_actions: bool = True
"""Whether to save a log of all agents' actions every step."""
save_transactions: bool = True
"""Whether to save transactions. If true, the session path will have a transactions folder."""
save_step_metadata: bool = False
"""Whether to save the RL agents' action, environment state, and other data at every single step."""
save_pcap_logs: bool = False
"""Whether to save PCAP logs."""
save_sys_logs: bool = False
"""Whether to save system logs."""
def __init__(self, settings: Optional[Settings] = None) -> None:
self.settings = settings or PrimaiteIO.Settings()
self.session_path: Path = self.generate_session_path()
# set global SIM_OUTPUT path
SIM_OUTPUT.path = self.session_path / "simulation_output"
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
self.agent_action_log: List[Dict] = []
# WARNING TODO: be careful not to re-initialise PrimaiteIO, because it creates a new session path
# each time it is instantiated. A refactor may be needed.
@@ -68,3 +65,56 @@ class SessionIO:
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
    """Build the path where a checkpoint model will be saved.

    The file lives under the session's ``checkpoints`` directory and is named after
    the agent and the episode number, with a ``.pt`` extension.

    :param agent_name: Name of the agent being checkpointed.
    :param episode: Episode number at which the checkpoint is taken.
    :return: Full path of the checkpoint file.
    """
    checkpoint_dir = self.session_path / "checkpoints"
    filename = f"{agent_name}_checkpoint_{episode}.pt"
    return checkpoint_dir / filename
def generate_agent_actions_save_path(self, episode: int) -> Path:
    """Build the path of the JSON file holding the given episode's agent-action log.

    :param episode: Episode number the log belongs to.
    :return: Full path of the agent actions file for that episode.
    """
    return Path(self.session_path, "agent_actions", f"episode_{episode}.json")
def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None:
    """Cache agent actions for a particular step.

    :param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected
        format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters].
        CAOS action is a string representing one of the CAOS actions.
        parameters is a dict of parameter names and values for that particular CAOS action.
        For example:
        {
            'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}),
            'defender': ('DONOTHING', {})
        }
    :type agent_actions: Dict
    :param episode: Episode number during which these actions occurred.
    :type episode: int
    :param timestep: Simulation timestep when these actions occurred.
    :type timestep: int
    """
    # Append one flat record per step. (Previously each record was wrapped in a
    # one-element list, which produced a needlessly nested structure in the saved JSON.)
    self.agent_action_log.append(
        {
            "episode": episode,
            "timestep": timestep,
            "agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()},
        }
    )
def write_agent_actions(self, episode: int) -> None:
    """Take the contents of the agent action log and write it to a JSON file.

    The output location comes from ``generate_agent_actions_save_path``; parent
    directories are created if they do not already exist.

    :param episode: Episode number
    :type episode: int
    """
    path = self.generate_agent_actions_save_path(episode=episode)
    path.parent.mkdir(exist_ok=True, parents=True)
    _LOGGER.info(f"Saving agent action log to {path}")
    # Opening the file in "w" mode creates it, so the previous explicit touch() was redundant.
    with open(path, "w") as file:
        json.dump(self.agent_action_log, fp=file, indent=1)
def clear_agent_actions(self) -> None:
    """Reset the agent action log back to an empty list."""
    self.agent_action_log = []
@classmethod
def from_config(cls, config: Optional[Dict]) -> "PrimaiteIO":
    """Create an instance of PrimaiteIO based on a configuration dict.

    :param config: Contents of the ``io_settings`` section of the config. May be None or empty,
        in which case default settings are used.
    :type config: Optional[Dict]
    :return: A PrimaiteIO instance whose settings reflect the supplied config.
    :rtype: PrimaiteIO
    """
    # Bug fix: the config argument was previously ignored entirely, so options such as
    # save_agent_actions set in the YAML io_settings section had no effect.
    if not config:
        return cls()
    return cls(settings=cls.Settings(**config))

View File

@@ -39,9 +39,9 @@ class SB3Policy(PolicyABC, identifier="SB3"):
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
if self.session.io_manager.settings.save_checkpoints:
if self.session.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)

View File

@@ -1,3 +1,4 @@
# raise DeprecationWarning("This module is deprecated")
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
@@ -5,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import SessionIO, SessionIOSettings
from primaite.session.io import PrimaiteIO
# from primaite.game.game import PrimaiteGame
from primaite.session.policy.policy import PolicyABC
@@ -53,12 +54,18 @@ class PrimaiteSession:
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.io_manager: Optional["SessionIO"] = None
self.io_manager: Optional["PrimaiteIO"] = None
"""IO manager for the session."""
self.game_cfg: Dict = game_cfg
"""Primaite Game object for managing main simulation loop and agents."""
self.save_checkpoints: bool = False
"""Whether to save checkpoints."""
self.checkpoint_interval: int = 10
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
def start_session(self) -> None:
"""Commence the training/eval session."""
print("Starting Primaite Session")
@@ -89,12 +96,13 @@ class PrimaiteSession:
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
io_manager = SessionIO(SessionIOSettings(**io_settings))
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
sess = cls(game_cfg=cfg)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":