From 2f456e7ae07660b6f9382f951cb59fe9b066fe31 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 4 Mar 2024 18:47:50 +0000 Subject: [PATCH 1/4] Move IO to environments from session and add agent logging --- .../config/_package_data/example_config.yaml | 1 + .../example_config_2_rl_agents.yaml | 81 ++++++++++--- .../game/agent/data_manipulation_bot.py | 2 +- src/primaite/game/game.py | 7 +- src/primaite/session/environment.py | 28 ++++- src/primaite/session/io.py | 108 +++++++++++++----- src/primaite/session/policy/sb3.py | 4 +- src/primaite/session/session.py | 16 ++- 8 files changed, 183 insertions(+), 64 deletions(-) diff --git a/src/primaite/config/_package_data/example_config.yaml b/src/primaite/config/_package_data/example_config.yaml index 8d1b4293..77296529 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/src/primaite/config/_package_data/example_config.yaml @@ -13,6 +13,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_agent_actions: true save_step_metadata: false save_pcap_logs: false save_sys_logs: true diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml index 260517b9..a5a1d08f 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/example_config_2_rl_agents.yaml @@ -15,6 +15,7 @@ training_config: io_settings: save_checkpoints: true checkpoint_interval: 5 + save_agent_actions: true save_step_metadata: false save_pcap_logs: false save_sys_logs: true @@ -36,6 +37,11 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -47,24 +53,38 @@ agents: - node_name: client_2 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - type: DUMMY - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 - - ref: client_1_green_user team: GREEN type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -76,10 +96,26 @@ agents: - node_name: client_1 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + reward_function: reward_components: - type: DUMMY @@ -1036,7 +1072,6 @@ agents: - simulation: network: nmne_config: @@ -1046,8 +1081,8 @@ simulation: nodes: - ref: router_1 - type: router hostname: router_1 + type: router num_ports: 5 ports: 1: @@ -1082,18 +1117,18 @@ simulation: protocol: ICMP - ref: switch_1 - type: switch hostname: switch_1 + type: switch num_ports: 8 - ref: switch_2 - type: switch hostname: switch_2 + type: switch num_ports: 8 - ref: domain_controller - type: server hostname: domain_controller + type: server ip_address: 192.168.1.10 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1105,8 +1140,8 @@ simulation: arcd.com: 192.168.1.12 # web server - ref: web_server - type: server hostname: web_server + type: server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1122,8 +1157,8 @@ simulation: - ref: database_server - type: server hostname: database_server + type: server ip_address: 192.168.1.14 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1137,8 +1172,8 @@ simulation: type: FTPClient - ref: backup_server - type: server hostname: backup_server + type: server ip_address: 192.168.1.16 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1148,8 +1183,8 @@ simulation: type: FTPServer - ref: security_suite - type: server hostname: security_suite + type: server ip_address: 192.168.1.110 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1160,8 +1195,8 @@ simulation: subnet_mask: 255.255.255.0 - ref: client_1 - type: computer hostname: client_1 + type: computer ip_address: 192.168.10.21 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -1178,13 +1213,17 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_1_dns_client type: DNSClient - ref: client_2 - type: computer hostname: client_2 + type: computer ip_address: 192.168.10.22 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -1201,6 +1240,10 @@ simulation: data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py index c758c926..16453433 100644 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ b/src/primaite/game/agent/data_manipulation_bot.py @@ -37,7 +37,7 @@ class DataManipulationAgent(AbstractScriptedAgent): :rtype: Tuple[str, Dict] """ if timestep < self.next_execution_timestep: - return "DONOTHING", {"dummy": 0} + return "DONOTHING", {} self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index cd88d832..394a8154 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -11,7 +11,6 @@ from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAge from primaite.game.agent.observations import ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.agent.scripted_agents import ProbabilisticAgent -from primaite.session.io import SessionIO, SessionIOSettings from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -210,10 +209,6 @@ class PrimaiteGame: :return: A PrimaiteGame object. :rtype: PrimaiteGame """ - io_settings = cfg.get("io_settings", {}) - _ = SessionIO(SessionIOSettings(**io_settings)) - # Instantiating this ensures that the game saves to the correct output dir even without being part of a session - game = cls() game.options = PrimaiteGameOptions(**cfg["game"]) game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False @@ -415,7 +410,7 @@ class PrimaiteGame: # CREATE AGENT if agent_type == "ProbabilisticAgent": # TODO: implement non-random agents and fix this parsing - settings = agent_cfg.get("agent_settings") + settings = agent_cfg.get("agent_settings", {}) new_agent = ProbabilisticAgent( agent_name=agent_cfg["ref"], action_space=action_space, diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index d54503a3..72d5ac9c 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -8,6 +8,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT @@ -32,6 +33,9 @@ class PrimaiteGymEnv(gymnasium.Env): self.episode_counter: int = 0 """Current episode number.""" + self.io = PrimaiteIO.from_config(game_config.get("io_settings", {})) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + @property def agent(self) -> ProxyAgent: """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" @@ -55,6 +59,10 @@ class PrimaiteGymEnv(gymnasium.Env): info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) + if self.io.settings.save_agent_actions: + self.io.store_agent_actions( + agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter + ) return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): @@ -79,6 +87,9 @@ class PrimaiteGymEnv(gymnasium.Env): f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {self.agent.reward_function.total_reward}" ) + if self.io.settings.save_agent_actions: + self.io.write_agent_actions(episode=self.episode_counter) + self.io.clear_agent_actions() self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 @@ -146,7 +157,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): """ self.game_config: Dict = env_config """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(self.game_config) + self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Reference to the primaite game""" self._agent_ids = list(self.game.rl_agents.keys()) """Agent ids. This is a list of strings of agent names.""" @@ -164,6 +175,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.action_space = gymnasium.spaces.Dict( {name: agent.action_manager.space for name, agent in self.agents.items()} ) + + self.io = PrimaiteIO.from_config(env_config.get("io_settings")) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + super().__init__() @property @@ -173,7 +188,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.game_config) + if self.io.settings.save_agent_actions: + self.io.write_agent_actions(episode=self.episode_counter) + self.io.clear_agent_actions() + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 state = self.game.get_sim_state() @@ -196,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - self.game.apply_agent_actions() + agent_actions = self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -215,6 +233,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: self._write_step_metadata_json(actions, state, rewards) + if self.io.settings.save_agent_actions: + self.io.store_agent_actions( + agent_actions=agent_actions, episode=self.episode_counter, timestep=self.game.step_counter + ) return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index b4b740e9..22d9dbeb 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -1,53 +1,50 @@ +import json from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Dict, List, Optional from pydantic import BaseModel, ConfigDict -from primaite import PRIMAITE_PATHS +from primaite import getLogger, PRIMAITE_PATHS from primaite.simulator import SIM_OUTPUT - -class SessionIOSettings(BaseModel): - """Schema for session IO settings.""" - - model_config = ConfigDict(extra="forbid") - - save_final_model: bool = True - """Whether to save the final model right at the end of training.""" - save_checkpoints: bool = False - """Whether to save a checkpoint model every `checkpoint_interval` episodes""" - checkpoint_interval: int = 10 - """How often to save a checkpoint model (if save_checkpoints is True).""" - save_logs: bool = True - """Whether to save logs""" - save_transactions: bool = True - """Whether to save transactions, If true, the session path will have a transactions folder.""" - save_tensorboard_logs: bool = False - """Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder.""" - save_step_metadata: bool = False - """Whether to save the RL agents' action, environment state, and other data at every single step.""" - save_pcap_logs: bool = False - """Whether to save PCAP logs.""" - save_sys_logs: bool = False - """Whether to save system logs.""" +_LOGGER = getLogger(__name__) -class SessionIO: +class PrimaiteIO: """ Class for managing session IO. Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. """ - def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: - self.settings: SessionIOSettings = settings + class Settings(BaseModel): + """Config schema for PrimaiteIO object.""" + + model_config = ConfigDict(extra="forbid") + + save_logs: bool = True + """Whether to save logs""" + save_agent_actions: bool = True + """Whether to save a log of all agents' actions every step.""" + save_transactions: bool = True + """Whether to save transactions, If true, the session path will have a transactions folder.""" + save_step_metadata: bool = False + """Whether to save the RL agents' action, environment state, and other data at every single step.""" + save_pcap_logs: bool = False + """Whether to save PCAP logs.""" + save_sys_logs: bool = False + """Whether to save system logs.""" + + def __init__(self, settings: Optional[Settings] = None) -> None: + self.settings = settings or PrimaiteIO.Settings() self.session_path: Path = self.generate_session_path() # set global SIM_OUTPUT path SIM_OUTPUT.path = self.session_path / "simulation_output" SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs + self.agent_action_log: List[Dict] = [] # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's # possible refactor needed @@ -68,3 +65,56 @@ class SessionIO: def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path: """Return the path where the checkpoint model will be saved (excluding filename extension).""" return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt" + + def generate_agent_actions_save_path(self, episode: int) -> Path: + """Return the path where agent actions will be saved.""" + return self.session_path / "agent_actions" / f"episode_{episode}.json" + + def store_agent_actions(self, agent_actions: Dict, episode: int, timestep: int) -> None: + """Cache agent actions for a particular step. + + :param agent_actions: Dictionary describing actions for any agents that acted in this timestep. The expected + format contains agent identifiers as keys. The keys should map to a tuple of [CAOS action, parameters] + CAOS action is a string representing one the CAOS actions. + parameters is a dict of parameter names and values for that particular CAOS action. + For example: + { + 'green1' : ('NODE_APPLICATION_EXECUTE', {'node_id':1, 'application_id':0}), + 'defender': ('DO_NOTHING', {}) + } + :type agent_actions: Dict + :param timestep: Simulation timestep when these actions occurred. + :type timestep: int + """ + self.agent_action_log.append( + [ + { + "episode": episode, + "timestep": timestep, + "agent_actions": {k: {"action": v[0], "parameters": v[1]} for k, v in agent_actions.items()}, + } + ] + ) + + def write_agent_actions(self, episode: int) -> None: + """Take the contents of the agent action log and write it to a file. + + :param episode: Episode number + :type episode: int + """ + path = self.generate_agent_actions_save_path(episode=episode) + path.parent.mkdir(exist_ok=True, parents=True) + path.touch() + _LOGGER.info(f"Saving agent action log to {path}") + with open(path, "w") as file: + json.dump(self.agent_action_log, fp=file, indent=1) + + def clear_agent_actions(self) -> None: + """Reset the agent action log back to an empty dictionary.""" + self.agent_action_log = [] + + @classmethod + def from_config(cls, config: Dict) -> "PrimaiteIO": + """Create an instance of PrimaiteIO based on a configuration dict.""" + new = cls() + return new diff --git a/src/primaite/session/policy/sb3.py b/src/primaite/session/policy/sb3.py index 254baf4d..6220371d 100644 --- a/src/primaite/session/policy/sb3.py +++ b/src/primaite/session/policy/sb3.py @@ -39,9 +39,9 @@ class SB3Policy(PolicyABC, identifier="SB3"): def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - if self.session.io_manager.settings.save_checkpoints: + if self.session.save_checkpoints: checkpoint_callback = CheckpointCallback( - save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval, + save_freq=timesteps_per_episode * self.session.checkpoint_interval, save_path=self.session.io_manager.generate_model_save_path("sb3"), name_prefix="sb3_model", ) diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index d244f6b0..84dd9b2f 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -1,3 +1,4 @@ +# raise DeprecationWarning("This module is deprecated") from enum import Enum from pathlib import Path from typing import Dict, List, Literal, Optional, Union @@ -5,7 +6,7 @@ from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv -from primaite.session.io import SessionIO, SessionIOSettings +from primaite.session.io import PrimaiteIO # from primaite.game.game import PrimaiteGame from primaite.session.policy.policy import PolicyABC @@ -53,12 +54,18 @@ class PrimaiteSession: self.policy: PolicyABC """The reinforcement learning policy.""" - self.io_manager: Optional["SessionIO"] = None + self.io_manager: Optional["PrimaiteIO"] = None """IO manager for the session.""" self.game_cfg: Dict = game_cfg """Primaite Game object for managing main simulation loop and agents.""" + self.save_checkpoints: bool = False + """Whether to save chcekpoints.""" + + self.checkpoint_interval: int = 10 + """If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes.""" + def start_session(self) -> None: """Commence the training/eval session.""" print("Starting Primaite Session") @@ -89,12 +96,13 @@ class PrimaiteSession: def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary.""" # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - io_manager = SessionIO(SessionIOSettings(**io_settings)) + io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {})) sess = cls(game_cfg=cfg) sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) + sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints") + sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval") # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": From c3010ff816882e47809bd4889b2b9c15253c2ce9 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 4 Mar 2024 18:59:03 +0000 Subject: [PATCH 2/4] Update changelog and docs --- CHANGELOG.md | 1 + docs/source/configuration/io_settings.rst | 7 +++++++ 2 files changed, 8 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index d54af980..48998d57 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed the order of service health state - Fixed an issue where starting a node didn't start the services on it - Added support for SQL INSERT command. +- Added ability to log each agent's action choices each step to a JSON file. diff --git a/docs/source/configuration/io_settings.rst b/docs/source/configuration/io_settings.rst index 96cc28fe..e5c6d2ce 100644 --- a/docs/source/configuration/io_settings.rst +++ b/docs/source/configuration/io_settings.rst @@ -19,6 +19,7 @@ This section configures how PrimAITE saves data during simulation and training. # save_logs: True # save_transactions: False # save_tensorboard_logs: False + save_agent_actions: True save_step_metadata: False save_pcap_logs: False save_sys_logs: False @@ -65,6 +66,12 @@ Defines how often to save the policy during training. *currently unused*. +``save_agent_actions`` + +Optional. Default value is ``True``. + +If ``True``, this will create a JSON file each episode detailing every agent's action each step in that episode, formatted according to the CAOS format. This includes scripted, RL, and red agents. + ``save_step_metadata`` ---------------------- From a7bfc56b98bd93f8c4043ffae678e72faa34f8f3 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 5 Mar 2024 11:21:49 +0000 Subject: [PATCH 3/4] Apply documentation changes based on PR review. --- CHANGELOG.md | 2 +- docs/source/configuration/io_settings.rst | 12 +----------- src/primaite/session/io.py | 12 +++++++----- src/primaite/session/session.py | 2 +- 4 files changed, 10 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8064a18e..cdf7b5c3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,7 +29,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed the order of service health state - Fixed an issue where starting a node didn't start the services on it - Added support for SQL INSERT command. -- Added ability to log each agent's action choices each step to a JSON file. +- Added ability to log each agent's action choices in each step to a JSON file. diff --git a/docs/source/configuration/io_settings.rst b/docs/source/configuration/io_settings.rst index e5c6d2ce..f9704541 100644 --- a/docs/source/configuration/io_settings.rst +++ b/docs/source/configuration/io_settings.rst @@ -18,7 +18,6 @@ This section configures how PrimAITE saves data during simulation and training. checkpoint_interval: 10 # save_logs: True # save_transactions: False - # save_tensorboard_logs: False save_agent_actions: True save_step_metadata: False save_pcap_logs: False @@ -56,21 +55,12 @@ Defines how often to save the policy during training. *currently unused*. -``save_transactions`` ---------------------- - -*currently unused*. - -``save_tensorboard_logs`` -------------------------- - -*currently unused*. ``save_agent_actions`` Optional. Default value is ``True``. -If ``True``, this will create a JSON file each episode detailing every agent's action each step in that episode, formatted according to the CAOS format. This includes scripted, RL, and red agents. +If ``True``, this will create a JSON file each episode detailing every agent's action in each step of that episode, formatted according to the CAOS format. This includes scripted, RL, and red agents. ``save_step_metadata`` ---------------------- diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index 22d9dbeb..3e21ed16 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -15,7 +15,7 @@ class PrimaiteIO: """ Class for managing session IO. - Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. + Currently it's handling path generation, but could expand to handle loading, transaction, and so on. """ class Settings(BaseModel): @@ -27,8 +27,6 @@ class PrimaiteIO: """Whether to save logs""" save_agent_actions: bool = True """Whether to save a log of all agents' actions every step.""" - save_transactions: bool = True - """Whether to save transactions, If true, the session path will have a transactions folder.""" save_step_metadata: bool = False """Whether to save the RL agents' action, environment state, and other data at every single step.""" save_pcap_logs: bool = False @@ -37,6 +35,12 @@ class PrimaiteIO: """Whether to save system logs.""" def __init__(self, settings: Optional[Settings] = None) -> None: + """ + Init the PrimaiteIO object. + + Note: Instantiating this object creates a new directory for outputs, and sets the global SIM_OUTPUT variable. + It is intended that this object is instantiated when a new environment is created. + """ self.settings = settings or PrimaiteIO.Settings() self.session_path: Path = self.generate_session_path() # set global SIM_OUTPUT path @@ -45,8 +49,6 @@ class PrimaiteIO: SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs self.agent_action_log: List[Dict] = [] - # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's - # possible refactor needed def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 84dd9b2f..9c935ae3 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -61,7 +61,7 @@ class PrimaiteSession: """Primaite Game object for managing main simulation loop and agents.""" self.save_checkpoints: bool = False - """Whether to save chcekpoints.""" + """Whether to save checkpoints.""" self.checkpoint_interval: int = 10 """If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes.""" From e117f94f43ad52f0064e0ecaf2fe46d606bbc209 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Tue, 5 Mar 2024 15:46:30 +0000 Subject: [PATCH 4/4] Minor doc fix --- docs/source/configuration/io_settings.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/source/configuration/io_settings.rst b/docs/source/configuration/io_settings.rst index f9704541..979dbfae 100644 --- a/docs/source/configuration/io_settings.rst +++ b/docs/source/configuration/io_settings.rst @@ -57,6 +57,7 @@ Defines how often to save the policy during training. ``save_agent_actions`` +---------------------- Optional. Default value is ``True``.