From 40686031e6069303f3d540ec4e154a9636efbfb3 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 13 Jun 2023 09:42:54 +0100 Subject: [PATCH] temp commit --- docs/source/config.rst | 132 +++++++----- pyproject.toml | 1 + src/primaite/agents/agent_abc.py | 36 ++++ src/primaite/agents/rllib.py | 177 +++++++++++++++ src/primaite/agents/sb3.py | 28 +++ src/primaite/common/enums.py | 27 +++ src/primaite/common/training_config.py | 95 -------- src/primaite/config/lay_down_config.py | 8 + src/primaite/config/training_config.py | 60 ++++-- src/primaite/primaite_session.py | 216 +++++++++++++++++++ tests/config/legacy/new_training_config.yaml | 19 +- 11 files changed, 626 insertions(+), 173 deletions(-) create mode 100644 src/primaite/agents/agent_abc.py create mode 100644 src/primaite/agents/rllib.py create mode 100644 src/primaite/agents/sb3.py delete mode 100644 src/primaite/common/training_config.py create mode 100644 src/primaite/primaite_session.py diff --git a/docs/source/config.rst b/docs/source/config.rst index 74898ec1..81468f17 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -22,46 +22,64 @@ The environment config file consists of the following attributes: * **agent_identifier** [enum] - This identifies the agent to use for the session. Select from one of the following: + This identifies the agent to use for the session. Select from one of the following: - * GENERIC - Where a user developed agent is to be used - * STABLE_BASELINES3_PPO - Use a SB3 PPO agent - * STABLE_BASELINES3_A2C - use a SB3 A2C agent + * GENERIC - Where a user developed agent is to be used + * STABLE_BASELINES3_PPO - Use a SB3 PPO agent + * STABLE_BASELINES3_A2C - use a SB3 A2C agent + +* **agent_framework** [enum] + + This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following: + + * NONE - Where a user developed agent is to be used + * SB3 - Stable Baselines3 + * RLLIB - Ray RLlib. 
+ +* **red_agent_identifier** + + This identifies the agent to use for the session. Select from one of the following: + + * A2C - Advantage Actor Critic + * PPO - Proximal Policy Optimization + * HARDCODED - A custom built deterministic agent + * RANDOM - A Stochastic random agent + * **action_type** [enum] - Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session + Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session * **num_episodes** [int] - This defines the number of episodes that the agent will train or be evaluated over. + This defines the number of episodes that the agent will train or be evaluated over. * **num_steps** [int] - Determines the number of steps to run in each episode of the session + Determines the number of steps to run in each episode of the session * **time_delay** [int] - The time delay (in milliseconds) to take between each step when running a GENERIC agent session + The time delay (in milliseconds) to take between each step when running a GENERIC agent session * **session_type** [text] - Type of session to be run (TRAINING or EVALUATION) + Type of session to be run (TRAINING, EVALUATION, or BOTH) * **load_agent** [bool] - Determine whether to load an agent from file + Determine whether to load an agent from file * **agent_load_file** [text] - File path and file name of agent if you're loading one in + File path and file name of agent if you're loading one in * **observation_space_high_value** [int] - The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases + The high value to use for values in the observation space. 
This is set to 1000000000 by default, and should not need changing in most cases **Reward-Based Config Values** @@ -69,95 +87,95 @@ Rewards are calculated based on the difference between the current state and ref * **Generic [all_ok]** [int] - The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) + The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) * **Node Hardware State [off_should_be_on]** [int] - The score to give when the node should be on, but is off + The score to give when the node should be on, but is off * **Node Hardware State [off_should_be_resetting]** [int] - The score to give when the node should be resetting, but is off + The score to give when the node should be resetting, but is off * **Node Hardware State [on_should_be_off]** [int] - The score to give when the node should be off, but is on + The score to give when the node should be off, but is on * **Node Hardware State [on_should_be_resetting]** [int] - The score to give when the node should be resetting, but is on + The score to give when the node should be resetting, but is on * **Node Hardware State [resetting_should_be_on]** [int] - The score to give when the node should be on, but is resetting + The score to give when the node should be on, but is resetting * **Node Hardware State [resetting_should_be_off]** [int] - The score to give when the node should be off, but is resetting + The score to give when the node should be off, but is resetting * **Node Hardware State [resetting]** [int] - The score to give when the node is resetting + The score to give when the node is resetting * **Node Operating System or Service State [good_should_be_patching]** [int] - The score to give when the state should be patching, but is good + The 
score to give when the state should be patching, but is good * **Node Operating System or Service State [good_should_be_compromised]** [int] - The score to give when the state should be compromised, but is good + The score to give when the state should be compromised, but is good * **Node Operating System or Service State [good_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is good + The score to give when the state should be overwhelmed, but is good * **Node Operating System or Service State [patching_should_be_good]** [int] - The score to give when the state should be good, but is patching + The score to give when the state should be good, but is patching * **Node Operating System or Service State [patching_should_be_compromised]** [int] - The score to give when the state should be compromised, but is patching + The score to give when the state should be compromised, but is patching * **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is patching + The score to give when the state should be overwhelmed, but is patching * **Node Operating System or Service State [patching]** [int] - The score to give when the state is patching + The score to give when the state is patching * **Node Operating System or Service State [compromised_should_be_good]** [int] - The score to give when the state should be good, but is compromised + The score to give when the state should be good, but is compromised * **Node Operating System or Service State [compromised_should_be_patching]** [int] - The score to give when the state should be patching, but is compromised + The score to give when the state should be patching, but is compromised * **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is compromised + The score to give when the state should be 
overwhelmed, but is compromised * **Node Operating System or Service State [compromised]** [int] - The score to give when the state is compromised + The score to give when the state is compromised * **Node Operating System or Service State [overwhelmed_should_be_good]** [int] - The score to give when the state should be good, but is overwhelmed + The score to give when the state should be good, but is overwhelmed * **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] - The score to give when the state should be patching, but is overwhelmed + The score to give when the state should be patching, but is overwhelmed * **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] - The score to give when the state should be compromised, but is overwhelmed + The score to give when the state should be compromised, but is overwhelmed * **Node Operating System or Service State [overwhelmed]** [int] - The score to give when the state is overwhelmed + The score to give when the state is overwhelmed * **Node File System State [good_should_be_repairing]** [int] @@ -261,37 +279,37 @@ Rewards are calculated based on the difference between the current state and ref * **IER Status [red_ier_running]** [int] - The score to give when a red agent IER is permitted to run + The score to give when a red agent IER is permitted to run * **IER Status [green_ier_blocked]** [int] - The score to give when a green agent IER is prevented from running + The score to give when a green agent IER is prevented from running **Patching / Reset Durations** * **os_patching_duration** [int] - The number of steps to take when patching an Operating System + The number of steps to take when patching an Operating System * **node_reset_duration** [int] - The number of steps to take when resetting a node's hardware state + The number of steps to take when resetting a node's hardware state * **service_patching_duration** [int] - The number of steps to take when 
patching a service + The number of steps to take when patching a service * **file_system_repairing_limit** [int]: - The number of steps to take when repairing the file system + The number of steps to take when repairing the file system * **file_system_restoring_limit** [int] - The number of steps to take when restoring the file system + The number of steps to take when restoring the file system * **file_system_scanning_limit** [int] - The number of steps to take when scanning the file system + The number of steps to take when scanning the file system The Lay Down Config ******************* @@ -300,22 +318,22 @@ The lay down config file consists of the following attributes: * **itemType: ACTIONS** [enum] - Determines whether a NODE or ACL action space format is adopted for the session + Determines whether a NODE or ACL action space format is adopted for the session * **itemType: OBSERVATION_SPACE** [dict] - Allows for user to configure observation space by combining one or more observation components. List of available - components is is :py:mod:'primaite.environment.observations'. + Allows for user to configure observation space by combining one or more observation components. List of available + components is is :py:mod:'primaite.environment.observations'. - The observation space config item should have a ``components`` key which is a list of components. Each component - config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the - component while it is being initialised. + The observation space config item should have a ``components`` key which is a list of components. Each component + config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the + component while it is being initialised. - This example illustrates the correct format for the observation space config item + This example illustrates the correct format for the observation space config item .. 
code-block::yaml - - itemType: OBSERVATION_SPACE + - item_type: OBSERVATION_SPACE components: - name: LINK_TRAFFIC_LEVELS options: @@ -328,15 +346,15 @@ The lay down config file consists of the following attributes: * **item_type: PORTS** [int] - Provides a list of ports modelled in this session + Provides a list of ports modelled in this session * **item_type: SERVICES** [freetext] - Provides a list of services modelled in this session + Provides a list of services modelled in this session * **item_type: NODE** - Defines a node included in the system laydown being simulated. It should consist of the following attributes: + Defines a node included in the system laydown being simulated. It should consist of the following attributes: * **id** [int]: Unique ID for this YAML item * **name** [freetext]: Human-readable name of the component @@ -355,7 +373,7 @@ The lay down config file consists of the following attributes: * **item_type: LINK** - Defines a link included in the system laydown being simulated. It should consist of the following attributes: + Defines a link included in the system laydown being simulated. It should consist of the following attributes: * **id** [int]: Unique ID for this YAML item * **name** [freetext]: Human-readable name of the component @@ -365,7 +383,7 @@ The lay down config file consists of the following attributes: * **item_type: GREEN_IER** - Defines a green agent Information Exchange Requirement (IER). It should consist of: + Defines a green agent Information Exchange Requirement (IER). It should consist of: * **id** [int]: Unique ID for this YAML item * **start_step** [int]: The start step (in the episode) for this IER to begin @@ -379,7 +397,7 @@ The lay down config file consists of the following attributes: * **item_type: RED_IER** - Defines a red agent Information Exchange Requirement (IER). It should consist of: + Defines a red agent Information Exchange Requirement (IER). 
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported only for static type checking; avoids a hard runtime
    # dependency on the environment module from this interface definition.
    from primaite.environment.primaite_env import Primaite


class AgentABC(ABC):
    """
    Abstract base class for agents that wrap a Primaite environment.

    Concrete subclasses (e.g. SB3- or RLlib-backed agents) implement the
    setup, learn/evaluate, and persistence hooks.
    """

    @abstractmethod
    def __init__(self, env: "Primaite"):
        """
        :param env: The Primaite environment the agent interacts with.
        """
        self._env: "Primaite" = env
        # The concrete underlying agent/algorithm; set by subclasses.
        self._agent = None

    @abstractmethod
    def _setup(self):
        """Perform any framework-specific agent construction."""

    @abstractmethod
    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        """Train the agent for the given number of time steps or episodes."""

    @abstractmethod
    def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
        """Evaluate the agent for the given number of time steps or episodes."""

    @abstractmethod
    def load(self):
        """Load a saved agent from file."""

    @abstractmethod
    def save(self):
        """Save the agent to file."""

    @abstractmethod
    def export(self):
        """Export the agent for use outside PrimAITE."""
def env_creator(env_config):
    """
    Create a Primaite environment from an RLlib ``env_config`` dict.

    :param env_config: Dict with ``training_config_path`` and
        ``lay_down_config_path`` keys.
    :return: A new ``Primaite`` environment instance.
    """
    training_config_path = env_config["training_config_path"]
    lay_down_config_path = env_config["lay_down_config_path"]
    return Primaite(training_config_path, lay_down_config_path, [])


def get_ppo_config(
    training_config_path: Union[str, Path],
    lay_down_config_path: Union[str, Path],
    framework: Optional[DLFramework] = DLFramework.TORCH
) -> PPOConfig:
    """
    Build an RLlib ``PPOConfig`` for the Primaite environment.

    The original annotated the return type as ``PPOConfig()`` (a call,
    not a type); fixed to ``PPOConfig``.

    :param training_config_path: Path to the training config YAML file.
    :param lay_down_config_path: Path to the lay down config YAML file.
    :param framework: The deep learning framework to use (default torch;
        RLlib's own default is "tf").
    :return: The configured ``PPOConfig``.
    """
    # Register the environment so RLlib can construct it by name.
    register_env("primaite", env_creator)

    config = PPOConfig()

    config_values = training_config.load(training_config_path)

    # Point the config at our registered environment.
    config.environment(
        env="primaite",
        env_config=dict(
            training_config_path=training_config_path,
            lay_down_config_path=lay_down_config_path
        )
    )

    # Normalise enum values to their names so the comparisons below work
    # whether the config holds enums or plain strings (the original
    # compared enum members directly to strings, which is always False).
    action_type = getattr(config_values.action_type, "name",
                          config_values.action_type)
    red_agent = getattr(config_values.red_agent_identifier, "name",
                        config_values.red_agent_identifier)

    # train_batch_size is the number of steps in a training iteration.
    # NOTE(review): "CONFIG" is not a RedAgentIdentifier member
    # (A2C/PPO/HARDCODED/RANDOM) - confirm the intended value.
    if red_agent == "RANDOM" and action_type == "NODE":
        config.training(train_batch_size=6000, lr=5e-5)
    elif red_agent == "RANDOM":
        config.training(train_batch_size=6000, lr=5e-5)
    elif red_agent == "CONFIG" and action_type == "NODE":
        config.training(train_batch_size=400, lr=5e-5)
    elif red_agent == "CONFIG":
        # Original tested action_type != "NONE", which is not a valid
        # ActionType name; treated as the CONFIG catch-all branch.
        config.training(train_batch_size=500, lr=5e-5)
    else:
        config.training(train_batch_size=500, lr=5e-5)

    # Deep learning framework selection.
    config.framework(framework=framework.value)

    # Set the log level to DEBUG, INFO, WARN, or ERROR.
    config.debugging(seed=415, log_level="ERROR")

    # Rollout workers used for sampling.
    config.rollouts(
        num_rollout_workers=4,  # number of parallel workers
        num_envs_per_worker=1,
        horizon=128,  # max number of steps in an episode
    )

    # NOTE(review): the original also called config.build() here and
    # discarded the result; train() builds the algorithm, so building
    # here was redundant work and has been removed.
    return config


def train(
    num_iterations: int,
    config: Optional[PPOConfig] = None,
    algo: Optional[Algorithm] = None
):
    """
    Train an RLlib algorithm for ``num_iterations`` training iterations.

    Requires either the algorithm config (new model) or the algorithm
    itself (continue training from a checkpoint). Note that RLlib
    iterations are not the same as episodes.

    :param num_iterations: Number of RLlib training iterations to run.
    :param config: Config to build a fresh algorithm from.
    :param algo: An existing algorithm to continue training.
    :raises ValueError: If neither config nor algo is provided.
    :return: The result dict from the final training iteration.
    """
    if config is None and algo is None:
        # The original fell through to an AttributeError on config.build().
        raise ValueError("Either 'config' or 'algo' must be provided")

    start_time = time.time()

    if algo is None:
        algo = config.build()
    elif config is None:
        config = algo.get_config()

    print(f"Algorithm type: {type(algo)}")

    for i in range(num_iterations):
        result = algo.train()
        print(
            f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}")
        # TODO(review): the original intent (commented out) was to save
        # only every N iterations or after the last; currently a
        # checkpoint is saved every iteration.
        checkpoint_file = algo.save("./")
        print(f"Checkpoint saved at {checkpoint_file}")

    # Convert num_iterations to episode / time-step counts using the
    # episode lengths recorded in the final iteration's stats.
    episode_lengths = result["hist_stats"]["episode_lengths"]
    num_episodes = len(episode_lengths) * num_iterations
    # The original computed sum(episode_lengths * num_iterations), which
    # repeats the list before summing; summing once and multiplying is
    # equivalent and avoids building the repeated list.
    num_timesteps = sum(episode_lengths) * num_iterations

    print(f"Training took {time.time() - start_time:.2f} seconds")
    print(
        f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}")
    return result


def load_model_from_checkpoint(config, checkpoint=None):
    """
    Build an algorithm from ``config`` and restore checkpointed weights.

    :param config: The algorithm config to build from.
    :param checkpoint: Optional explicit checkpoint path; when omitted
        the most recent checkpoint for this config is used.
    :return: The restored algorithm.
    """
    # Create an empty Algorithm, then restore its state.
    algo = config.build()

    if checkpoint is None:
        checkpoint = get_most_recent_checkpoint(config)

    algo.restore(checkpoint)
    return algo


def get_most_recent_checkpoint(config):
    """
    Get the most recent checkpoint for the configured action type, red
    agent and algorithm.

    "Most recent" means the checkpoint with the highest iteration number
    in its directory name, not the newest by modification time.

    :param config: The algorithm config the checkpoint was produced from.
    :return: Path to the checkpoint's rllib_checkpoint.json file.
    """
    # NOTE(review): config.env_config holds the env_creator kwargs
    # (training_config_path / lay_down_config_path), so taking the first
    # value and reading .action_type off it looks wrong - confirm.
    env_config = list(config.env_config.values())[0]
    action_type = env_config.action_type
    red_agent = env_config.red_agent_identifier
    algo_name = config.algo_class.__name__

    relevant_checkpoints = glob.glob(
        f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*"
    )
    checkpoint_numbers = [int(p.split("_")[1]) for p in relevant_checkpoints]
    # Checkpoint directory numbers are zero-padded to six digits.
    checkpoint_number_to_use = str(max(checkpoint_numbers)).zfill(6)
    checkpoint = (
        relevant_checkpoints[0].split("_")[0]
        + "_"
        + checkpoint_number_to_use
        + "/rllib_checkpoint.json"
    )

    return checkpoint
class SessionType(Enum):
    """The type of PrimAITE session to be run."""

    TRAINING = 1
    EVALUATION = 2
    BOTH = 3


class VerboseLevel(Enum):
    """PrimAITE Session output verbose level."""

    NO_OUTPUT = 0
    INFO = 1
    DEBUG = 2


class AgentFramework(Enum):
    """The framework used to instantiate the agent algorithm."""

    NONE = 0  # a user-developed (generic) agent
    SB3 = 1  # Stable-Baselines3
    RLLIB = 2  # Ray RLlib


class RedAgentIdentifier(Enum):
    """The red agent algorithm/class to be used."""

    A2C = 1
    PPO = 2
    HARDCODED = 3
    RANDOM = 4
-# """The config class.""" -# from dataclasses import dataclass -# -# from primaite.common.enums import ActionType -# -# -# @dataclass() -# class TrainingConfig: -# """Class to hold main config values.""" -# -# # Generic -# agent_identifier: str # The Red Agent algo/class to be used -# action_type: ActionType # type of action to use (NODE/ACL/ANY) -# num_episodes: int # number of episodes to train over -# num_steps: int # number of steps in an episode -# time_delay: int # delay between steps (ms) - applies to generic agents only -# # file -# session_type: str # the session type to run (TRAINING or EVALUATION) -# load_agent: str # Determine whether to load an agent from file -# agent_load_file: str # File path and file name of agent if you're loading one in -# -# # Environment -# observation_space_high_value: int # The high value for the observation space -# -# # Reward values -# # Generic -# all_ok: int -# # Node Hardware State -# off_should_be_on: int -# off_should_be_resetting: int -# on_should_be_off: int -# on_should_be_resetting: int -# resetting_should_be_on: int -# resetting_should_be_off: int -# resetting: int -# # Node Software or Service State -# good_should_be_patching: int -# good_should_be_compromised: int -# good_should_be_overwhelmed: int -# patching_should_be_good: int -# patching_should_be_compromised: int -# patching_should_be_overwhelmed: int -# patching: int -# compromised_should_be_good: int -# compromised_should_be_patching: int -# compromised_should_be_overwhelmed: int -# compromised: int -# overwhelmed_should_be_good: int -# overwhelmed_should_be_patching: int -# overwhelmed_should_be_compromised: int -# overwhelmed: int -# # Node File System State -# good_should_be_repairing: int -# good_should_be_restoring: int -# good_should_be_corrupt: int -# good_should_be_destroyed: int -# repairing_should_be_good: int -# repairing_should_be_restoring: int -# repairing_should_be_corrupt: int -# repairing_should_be_destroyed: int # Repairing does not 
fix destroyed state - you need to restore -# -# repairing: int -# restoring_should_be_good: int -# restoring_should_be_repairing: int -# restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption) -# -# restoring_should_be_destroyed: int -# restoring: int -# corrupt_should_be_good: int -# corrupt_should_be_repairing: int -# corrupt_should_be_restoring: int -# corrupt_should_be_destroyed: int -# corrupt: int -# destroyed_should_be_good: int -# destroyed_should_be_repairing: int -# destroyed_should_be_restoring: int -# destroyed_should_be_corrupt: int -# destroyed: int -# scanning: int -# # IER status -# red_ier_running: int -# green_ier_blocked: int -# -# # Patching / Reset -# os_patching_duration: int # The time taken to patch the OS -# node_reset_duration: int # The time taken to reset a node (hardware) -# node_booting_duration = 0 # The Time taken to turn on the node -# node_shutdown_duration = 0 # The time taken to turn off the node -# service_patching_duration: int # The time taken to patch a service -# file_system_repairing_limit: int # The time take to repair a file -# file_system_restoring_limit: int # The time take to restore a file -# file_system_scanning_limit: int # The time taken to scan the file system -# # Patching / Reset -# diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 46389297..4fd2142e 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -2,6 +2,8 @@ from pathlib import Path from typing import Final +import networkx + from primaite import USERS_CONFIG_DIR, getLogger _LOGGER = getLogger(__name__) @@ -9,6 +11,12 @@ _LOGGER = getLogger(__name__) _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +# class LayDownConfig: +# network: networkx.Graph +# POL +# EIR +# ACL + def ddos_basic_one_config_path() -> Path: """ The path to the example lay_down_config_1_DDOS_basic.yaml file. 
diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..b0956d42 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,4 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +from __future__ import annotations + from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Final, Union, Optional @@ -6,7 +8,8 @@ from typing import Any, Dict, Final, Union, Optional import yaml from primaite import USERS_CONFIG_DIR, getLogger -from primaite.common.enums import ActionType +from primaite.common.enums import ActionType, RedAgentIdentifier, \ + AgentFramework, SessionType _LOGGER = getLogger(__name__) @@ -16,10 +19,11 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training @dataclass() class TrainingConfig: """The Training Config class.""" + agent_framework: AgentFramework = AgentFramework.SB3 + "The agent framework." - # Generic - agent_identifier: str = "STABLE_BASELINES3_A2C" - "The Red Agent algo/class to be used." + red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO + "The red agent/algo class." action_type: ActionType = ActionType.ANY "The ActionType to use." @@ -38,8 +42,8 @@ class TrainingConfig: "The delay between steps (ms). Applies to generic agents only." # file - session_type: str = "TRAINING" - "the session type to run (TRAINING or EVALUATION)" + session_type: SessionType = SessionType.TRAINING + "The type of PrimAITE session to run." load_agent: str = False "Determine whether to load an agent from file." @@ -137,6 +141,24 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system." 
+ @classmethod + def from_dict( + cls, + config_dict: Dict[str, Union[str, int, bool]] + ) -> TrainingConfig: + field_enum_map = { + "agent_framework": AgentFramework, + "red_agent_identifier": RedAgentIdentifier, + "action_type": ActionType, + "session_type": SessionType + } + + for field, enum_class in field_enum_map.items(): + if field in config_dict: + config_dict[field] = enum_class[field] + + return TrainingConfig(**config_dict) + def to_dict(self, json_serializable: bool = True): """ Serialise the ``TrainingConfig`` as dict. @@ -196,10 +218,8 @@ def load(file_path: Union[str, Path], f"from legacy format. Attempting to use file as is." ) _LOGGER.error(msg) - # Convert values to Enums - config["action_type"] = ActionType[config["action_type"]] try: - return TrainingConfig(**config) + return TrainingConfig.from_dict(**config) except TypeError as e: msg = ( f"Error when creating an instance of {TrainingConfig} " @@ -214,22 +234,30 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + agent_framework: AgentFramework = AgentFramework.SB3, + red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO, + action_type: ActionType = ActionType.ANY, + num_steps: int = 256 ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. :param legacy_config_dict: A legacy training config dict. - :param num_steps: The number of steps to set as legacy training configs - don't have num_steps values. + :param agent_framework: The agent framework to use as legacy training + configs don't have agent_framework values. + :param red_agent_identifier: The red agent identifier to use as legacy + training configs don't have red_agent_identifier values. :param action_type: The action space type to set as legacy training configs don't have action_type values. 
def _get_session_path(session_timestamp: datetime) -> Path:
    """
    Get the directory path the session will output to.

    This is set in the format of:
    ~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
    The directory (and any missing parents) is created if needed.

    :param session_timestamp: This is the datetime that the session started.
    :return: The session directory path.
    """
    date_dir = session_timestamp.strftime("%Y-%m-%d")
    session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
    session_path = SESSIONS_DIR / date_dir / session_dir
    session_path.mkdir(exist_ok=True, parents=True)
    _LOGGER.debug(f"Created PrimAITE Session path: {session_path}")

    return session_path


class PrimaiteSession:
    """
    Orchestrates a PrimAITE session: environment setup, learning,
    evaluation and session metadata output.
    """

    def __init__(
        self,
        training_config_path: Union[str, Path],
        lay_down_config_path: Union[str, Path],
        auto: bool = True
    ):
        """
        :param training_config_path: Path to the training config YAML file.
        :param lay_down_config_path: Path to the lay down config YAML file.
        :param auto: If True, immediately call setup() and learn().
        """
        if not isinstance(training_config_path, Path):
            training_config_path = Path(training_config_path)
        self._training_config_path: Final[Path] = training_config_path

        if not isinstance(lay_down_config_path, Path):
            lay_down_config_path = Path(lay_down_config_path)
        self._lay_down_config_path: Final[Path] = lay_down_config_path

        self._auto: Final[bool] = auto

        self._uuid: str = str(uuid4())
        self._session_timestamp: Final[datetime] = datetime.now()
        self._session_path: Final[Path] = _get_session_path(
            self._session_timestamp
        )
        self._timestamp_str: Final[str] = self._session_timestamp.strftime(
            "%Y-%m-%d_%H-%M-%S")
        self._metadata_path = self._session_path / "session_metadata.json"

        self._env = None
        self._training_config = None
        self._transaction_list: list = []
        self._can_learn: bool = False

        if self._auto:
            self.setup()
            # The original called learn() with no arguments while learn()
            # had required parameters - a guaranteed TypeError. learn()'s
            # parameters now default to None.
            self.learn()

    @property
    def uuid(self):
        """The session UUID."""
        return self._uuid

    def _setup_primaite_env(self, transaction_list: Optional[list] = None):
        """
        Create the Primaite environment for this session.

        :param transaction_list: Optional list to collect step
            transactions into; a new list is created when omitted.
        """
        if not transaction_list:
            transaction_list = []
        # Kept so learn() can write the transaction log at the end.
        self._transaction_list = transaction_list
        self._env: Primaite = Primaite(
            training_config_path=self._training_config_path,
            lay_down_config_path=self._lay_down_config_path,
            transaction_list=transaction_list,
            session_path=self._session_path,
            timestamp_str=self._timestamp_str
        )
        self._training_config: TrainingConfig = self._env.training_config

    def _write_session_metadata_file(self):
        """
        Write the ``session_metadata.json`` file.

        Creates a ``session_metadata.json`` in the ``session_dir`` directory
        and adds the following key/value pairs:

        - uuid: The UUID assigned to the session upon instantiation.
        - start_datetime: The date & time the session started in iso format.
        - end_datetime: NULL.
        - total_episodes: NULL.
        - total_time_steps: NULL.
        - env:
            - training_config:
                - All training config items
            - lay_down_config:
                - All lay down config items
        """
        metadata_dict = {
            "uuid": self._uuid,
            "start_datetime": self._session_timestamp.isoformat(),
            "end_datetime": None,
            "total_episodes": None,
            "total_time_steps": None,
            "env": {
                "training_config": self._env.training_config.to_dict(
                    json_serializable=True
                ),
                "lay_down_config": self._env.lay_down_config,
            },
        }
        _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)

    def _update_session_metadata_file(self):
        """
        Update the ``session_metadata.json`` file.

        Updates the ``session_metadata.json`` in the ``session_dir``
        directory with the end datetime and the episode/step totals
        taken from the environment.
        """
        with open(self._metadata_path, "r") as file:
            metadata_dict = json.load(file)

        metadata_dict["end_datetime"] = datetime.now().isoformat()
        metadata_dict["total_episodes"] = self._env.episode_count
        metadata_dict["total_time_steps"] = self._env.total_step_count

        _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}")
        with open(self._metadata_path, "w") as file:
            json.dump(metadata_dict, file)

    def setup(self):
        """Create the environment and mark the session ready to learn."""
        self._setup_primaite_env()
        self._can_learn = True

    def learn(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        iterations: Optional[int] = None,
        **kwargs
    ):
        """
        Run the training session against the configured agent.

        :param time_steps: Optional number of time steps to train for.
        :param episodes: Optional number of episodes to train for.
        :param iterations: Optional number of training iterations.
        """
        if not self._can_learn:
            _LOGGER.error("Cannot learn; call setup() first")
            return

        # NOTE(review): this dispatch is unfinished WIP. The helpers
        # (run_generic, run_stable_baselines3_ppo, run_stable_baselines3_a2c,
        # write_transaction_to_file) are not imported in this module, and
        # TrainingConfig may no longer expose agent_identifier (it appears
        # superseded by agent_framework) - confirm before release. The
        # original also compared the whole config object to strings in the
        # PPO/A2C branches; fixed to compare the attribute.
        agent_identifier = getattr(
            self._training_config, "agent_identifier", None
        )
        if agent_identifier == "GENERIC":
            run_generic(
                env=self._env,
                config_values=self._training_config
            )
        elif agent_identifier == "STABLE_BASELINES3_PPO":
            run_stable_baselines3_ppo(
                env=self._env,
                config_values=self._training_config,
                session_path=self._session_path,
                timestamp_str=self._timestamp_str,
            )
        elif agent_identifier == "STABLE_BASELINES3_A2C":
            run_stable_baselines3_a2c(
                env=self._env,
                config_values=self._training_config,
                session_path=self._session_path,
                timestamp_str=self._timestamp_str,
            )

        print("Session finished")
        _LOGGER.debug("Session finished")

        print("Saving transaction logs...")
        write_transaction_to_file(
            transaction_list=self._transaction_list,
            session_path=self._session_path,
            timestamp_str=self._timestamp_str,
        )

        print("Updating Session Metadata file...")
        # The original called a free function with kwargs the method does
        # not accept; use the instance method.
        self._update_session_metadata_file()

        print("Finished")
        _LOGGER.debug("Finished")

    def evaluate(
        self,
        time_steps: Optional[int] = None,
        episodes: Optional[int] = None,
        **kwargs
    ):
        """Run an evaluation session. Not yet implemented."""
        pass

    def export(self):
        """Export the trained agent. Not yet implemented."""
        pass

    @classmethod
    def import_agent(
        cls,
        gent_path: str,
        training_config_path: str,
        lay_down_config_path: str
    ) -> PrimaiteSession:
        """
        Create a session around a previously exported agent.

        :param gent_path: Path to the exported agent.
            TODO(review): name looks like a typo for agent_path and the
            parameter is currently unused - confirm before renaming, as
            renaming would break keyword callers.
        :param training_config_path: Path to the training config YAML file.
        :param lay_down_config_path: Path to the lay down config YAML file.
        :return: The new session.
        """
        session = PrimaiteSession(training_config_path, lay_down_config_path)

        # Reset the UUID.
        # TODO(review): this blanks the UUID rather than restoring the
        # imported agent's UUID from its session metadata - confirm intent.
        session._uuid = ""

        return session
diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy/new_training_config.yaml index becc1799..44897bfa 100644 --- a/tests/config/legacy/new_training_config.yaml +++ b/tests/config/legacy/new_training_config.yaml @@ -1,11 +1,20 @@ # Main Config File # Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C + +# Sets which agent algorithm framework will be used: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray[RLlib]) +# "NONE" (Custom Agent) +agent_framework: RLLIB + +# Sets which Red Agent algorithm/class will be used: +# "PPO" (Proximal Policy Optimization) +# "A2C" (Advantage Actor Critic) +# "HARDCODED" (Custom Agent) +# "RANDOM" (Random Action) +red_agent_identifier: PPO + # Sets how the Action Space is defined: # "NODE" # "ACL"