From eb3368edd6d77ea457eeb5abff4590ac780facfe Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 13 Jun 2023 09:42:54 +0100 Subject: [PATCH 01/43] temp commit --- docs/source/config.rst | 132 +++++++----- pyproject.toml | 1 + src/primaite/agents/agent_abc.py | 36 ++++ src/primaite/agents/rllib.py | 177 +++++++++++++++ src/primaite/agents/sb3.py | 28 +++ src/primaite/common/enums.py | 27 +++ src/primaite/common/training_config.py | 95 -------- src/primaite/config/lay_down_config.py | 8 + src/primaite/config/training_config.py | 60 ++++-- src/primaite/primaite_session.py | 216 +++++++++++++++++++ tests/config/legacy/new_training_config.yaml | 19 +- 11 files changed, 626 insertions(+), 173 deletions(-) create mode 100644 src/primaite/agents/agent_abc.py create mode 100644 src/primaite/agents/rllib.py create mode 100644 src/primaite/agents/sb3.py delete mode 100644 src/primaite/common/training_config.py create mode 100644 src/primaite/primaite_session.py diff --git a/docs/source/config.rst b/docs/source/config.rst index 74898ec1..81468f17 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -22,46 +22,64 @@ The environment config file consists of the following attributes: * **agent_identifier** [enum] - This identifies the agent to use for the session. Select from one of the following: + This identifies the agent to use for the session. Select from one of the following: - * GENERIC - Where a user developed agent is to be used - * STABLE_BASELINES3_PPO - Use a SB3 PPO agent - * STABLE_BASELINES3_A2C - use a SB3 A2C agent + * GENERIC - Where a user developed agent is to be used + * STABLE_BASELINES3_PPO - Use a SB3 PPO agent + * STABLE_BASELINES3_A2C - use a SB3 A2C agent + +* **agent_framework** [enum] + + This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following: + + * NONE - Where a user developed agent is to be used + * SB3 - Stable Baselines3 + * RLLIB - Ray RLlib. 
+ +* **red_agent_identifier** + + This identifies the agent to use for the session. Select from one of the following: + + * A2C - Advantage Actor Critic + * PPO - Proximal Policy Optimization + * HARDCODED - A custom built deterministic agent + * RANDOM - A Stochastic random agent + * **action_type** [enum] - Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session + Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session * **num_episodes** [int] - This defines the number of episodes that the agent will train or be evaluated over. + This defines the number of episodes that the agent will train or be evaluated over. * **num_steps** [int] - Determines the number of steps to run in each episode of the session + Determines the number of steps to run in each episode of the session * **time_delay** [int] - The time delay (in milliseconds) to take between each step when running a GENERIC agent session + The time delay (in milliseconds) to take between each step when running a GENERIC agent session * **session_type** [text] - Type of session to be run (TRAINING or EVALUATION) + Type of session to be run (TRAINING, EVALUATION, or BOTH) * **load_agent** [bool] - Determine whether to load an agent from file + Determine whether to load an agent from file * **agent_load_file** [text] - File path and file name of agent if you're loading one in + File path and file name of agent if you're loading one in * **observation_space_high_value** [int] - The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases + The high value to use for values in the observation space. 
This is set to 1000000000 by default, and should not need changing in most cases **Reward-Based Config Values** @@ -69,95 +87,95 @@ Rewards are calculated based on the difference between the current state and ref * **Generic [all_ok]** [int] - The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) + The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken) * **Node Hardware State [off_should_be_on]** [int] - The score to give when the node should be on, but is off + The score to give when the node should be on, but is off * **Node Hardware State [off_should_be_resetting]** [int] - The score to give when the node should be resetting, but is off + The score to give when the node should be resetting, but is off * **Node Hardware State [on_should_be_off]** [int] - The score to give when the node should be off, but is on + The score to give when the node should be off, but is on * **Node Hardware State [on_should_be_resetting]** [int] - The score to give when the node should be resetting, but is on + The score to give when the node should be resetting, but is on * **Node Hardware State [resetting_should_be_on]** [int] - The score to give when the node should be on, but is resetting + The score to give when the node should be on, but is resetting * **Node Hardware State [resetting_should_be_off]** [int] - The score to give when the node should be off, but is resetting + The score to give when the node should be off, but is resetting * **Node Hardware State [resetting]** [int] - The score to give when the node is resetting + The score to give when the node is resetting * **Node Operating System or Service State [good_should_be_patching]** [int] - The score to give when the state should be patching, but is good + The 
score to give when the state should be patching, but is good * **Node Operating System or Service State [good_should_be_compromised]** [int] - The score to give when the state should be compromised, but is good + The score to give when the state should be compromised, but is good * **Node Operating System or Service State [good_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is good + The score to give when the state should be overwhelmed, but is good * **Node Operating System or Service State [patching_should_be_good]** [int] - The score to give when the state should be good, but is patching + The score to give when the state should be good, but is patching * **Node Operating System or Service State [patching_should_be_compromised]** [int] - The score to give when the state should be compromised, but is patching + The score to give when the state should be compromised, but is patching * **Node Operating System or Service State [patching_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is patching + The score to give when the state should be overwhelmed, but is patching * **Node Operating System or Service State [patching]** [int] - The score to give when the state is patching + The score to give when the state is patching * **Node Operating System or Service State [compromised_should_be_good]** [int] - The score to give when the state should be good, but is compromised + The score to give when the state should be good, but is compromised * **Node Operating System or Service State [compromised_should_be_patching]** [int] - The score to give when the state should be patching, but is compromised + The score to give when the state should be patching, but is compromised * **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int] - The score to give when the state should be overwhelmed, but is compromised + The score to give when the state should be 
overwhelmed, but is compromised * **Node Operating System or Service State [compromised]** [int] - The score to give when the state is compromised + The score to give when the state is compromised * **Node Operating System or Service State [overwhelmed_should_be_good]** [int] - The score to give when the state should be good, but is overwhelmed + The score to give when the state should be good, but is overwhelmed * **Node Operating System or Service State [overwhelmed_should_be_patching]** [int] - The score to give when the state should be patching, but is overwhelmed + The score to give when the state should be patching, but is overwhelmed * **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int] - The score to give when the state should be compromised, but is overwhelmed + The score to give when the state should be compromised, but is overwhelmed * **Node Operating System or Service State [overwhelmed]** [int] - The score to give when the state is overwhelmed + The score to give when the state is overwhelmed * **Node File System State [good_should_be_repairing]** [int] @@ -261,37 +279,37 @@ Rewards are calculated based on the difference between the current state and ref * **IER Status [red_ier_running]** [int] - The score to give when a red agent IER is permitted to run + The score to give when a red agent IER is permitted to run * **IER Status [green_ier_blocked]** [int] - The score to give when a green agent IER is prevented from running + The score to give when a green agent IER is prevented from running **Patching / Reset Durations** * **os_patching_duration** [int] - The number of steps to take when patching an Operating System + The number of steps to take when patching an Operating System * **node_reset_duration** [int] - The number of steps to take when resetting a node's hardware state + The number of steps to take when resetting a node's hardware state * **service_patching_duration** [int] - The number of steps to take when 
patching a service + The number of steps to take when patching a service * **file_system_repairing_limit** [int]: - The number of steps to take when repairing the file system + The number of steps to take when repairing the file system * **file_system_restoring_limit** [int] - The number of steps to take when restoring the file system + The number of steps to take when restoring the file system * **file_system_scanning_limit** [int] - The number of steps to take when scanning the file system + The number of steps to take when scanning the file system The Lay Down Config ******************* @@ -300,22 +318,22 @@ The lay down config file consists of the following attributes: * **itemType: ACTIONS** [enum] - Determines whether a NODE or ACL action space format is adopted for the session + Determines whether a NODE or ACL action space format is adopted for the session * **itemType: OBSERVATION_SPACE** [dict] - Allows for user to configure observation space by combining one or more observation components. List of available - components is is :py:mod:'primaite.environment.observations'. + Allows for user to configure observation space by combining one or more observation components. List of available + components is is :py:mod:'primaite.environment.observations'. - The observation space config item should have a ``components`` key which is a list of components. Each component - config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the - component while it is being initialised. + The observation space config item should have a ``components`` key which is a list of components. Each component + config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the + component while it is being initialised. - This example illustrates the correct format for the observation space config item + This example illustrates the correct format for the observation space config item .. 
code-block::yaml - - itemType: OBSERVATION_SPACE + - item_type: OBSERVATION_SPACE components: - name: LINK_TRAFFIC_LEVELS options: @@ -328,15 +346,15 @@ The lay down config file consists of the following attributes: * **item_type: PORTS** [int] - Provides a list of ports modelled in this session + Provides a list of ports modelled in this session * **item_type: SERVICES** [freetext] - Provides a list of services modelled in this session + Provides a list of services modelled in this session * **item_type: NODE** - Defines a node included in the system laydown being simulated. It should consist of the following attributes: + Defines a node included in the system laydown being simulated. It should consist of the following attributes: * **id** [int]: Unique ID for this YAML item * **name** [freetext]: Human-readable name of the component @@ -355,7 +373,7 @@ The lay down config file consists of the following attributes: * **item_type: LINK** - Defines a link included in the system laydown being simulated. It should consist of the following attributes: + Defines a link included in the system laydown being simulated. It should consist of the following attributes: * **id** [int]: Unique ID for this YAML item * **name** [freetext]: Human-readable name of the component @@ -365,7 +383,7 @@ The lay down config file consists of the following attributes: * **item_type: GREEN_IER** - Defines a green agent Information Exchange Requirement (IER). It should consist of: + Defines a green agent Information Exchange Requirement (IER). It should consist of: * **id** [int]: Unique ID for this YAML item * **start_step** [int]: The start step (in the episode) for this IER to begin @@ -379,7 +397,7 @@ The lay down config file consists of the following attributes: * **item_type: RED_IER** - Defines a red agent Information Exchange Requirement (IER). It should consist of: + Defines a red agent Information Exchange Requirement (IER). 
It should consist of: * **id** [int]: Unique ID for this YAML item * **start_step** [int]: The start step (in the episode) for this IER to begin diff --git a/pyproject.toml b/pyproject.toml index 58f63efa..aa9f5fdc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "numpy==1.23.5", "platformdirs==3.5.1", "PyYAML==6.0", + "ray[rllib]==2.2.0", "stable-baselines3==1.6.2", "typer[all]==0.9.0" ] diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py new file mode 100644 index 00000000..c500128d --- /dev/null +++ b/src/primaite/agents/agent_abc.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from primaite.environment.primaite_env import Primaite + + +class AgentABC(ABC): + + @abstractmethod + def __init__(self, env: Primaite): + self._env: Primaite = env + self._agent = None + + @abstractmethod + def _setup(self): + pass + + @abstractmethod + def learn(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + @abstractmethod + def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + @abstractmethod + def load(self): + pass + + @abstractmethod + def save(self): + pass + + @abstractmethod + def export(self): + pass \ No newline at end of file diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py new file mode 100644 index 00000000..d07265b4 --- /dev/null +++ b/src/primaite/agents/rllib.py @@ -0,0 +1,177 @@ +import glob +import time +from enum import Enum +from pathlib import Path +from typing import Union, Optional + +from ray.rllib.algorithms import Algorithm +from ray.rllib.algorithms.ppo import PPOConfig +from ray.tune.registry import register_env + +from primaite.config import training_config +from primaite.environment.primaite_env import Primaite + + +class DLFramework(Enum): + """The DL Frameworks enumeration.""" + TF = "tf" + TF2 = "tf2" + TORCH = "torch" + + +def env_creator(env_config): + training_config_path 
= env_config["training_config_path"] + lay_down_config_path = env_config["lay_down_config_path"] + return Primaite(training_config_path, lay_down_config_path, []) + + +def get_ppo_config( + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + framework: Optional[DLFramework] = DLFramework.TORCH +) -> PPOConfig(): + # Register environment + register_env("primaite", env_creator) + + # Setup PPO + config = PPOConfig() + + config_values = training_config.load(training_config_path) + + # Setup our config object to use our environment + config.environment( + env="primaite", + env_config=dict( + training_config_path=training_config_path, + lay_down_config_path=lay_down_config_path + ) + ) + + env_config = config_values + action_type = env_config.action_type + red_agent = env_config.red_agent_identifier + + if red_agent == "RANDOM" and action_type == "NODE": + config.training( + train_batch_size=6000, lr=5e-5 + ) # number of steps in a training iteration + elif red_agent == "RANDOM" and action_type != "NODE": + config.training(train_batch_size=6000, lr=5e-5) + elif red_agent == "CONFIG" and action_type == "NODE": + config.training(train_batch_size=400, lr=5e-5) + elif red_agent == "CONFIG" and action_type != "NONE": + config.training(train_batch_size=500, lr=5e-5) + else: + config.training(train_batch_size=500, lr=5e-5) + + # Decide if you want torch or tensorflow DL framework. 
Default is "tf" + config.framework(framework=framework.value) + + # Set the log level to DEBUG, INFO, WARN, or ERROR + config.debugging(seed=415, log_level="ERROR") + + # Setup evaluation + # Explicitly set "explore"=False to override default + # config.evaluation( + # evaluation_interval=100, + # evaluation_duration=20, + # # evaluation_duration_unit="timesteps",) #default episodes + # evaluation_config={"explore": False}, + # ) + + # Setup sampling rollout workers + config.rollouts( + num_rollout_workers=4, + num_envs_per_worker=1, + horizon=128, # num parralel workiers + ) # max num steps in an episode + + config.build() # Build config + + return config + + +def train( + num_iterations: int, + config: Optional[PPOConfig] = None, + algo: Optional[Algorithm] = None +): + """ + + Requires either the algorithm config (new model) or the algorithm itself (continue training from checkpoint) + """ + + start_time = time.time() + + if algo is None: + algo = config.build() + elif config is None: + config = algo.get_config() + + print(f"Algorithm type: {type(algo)}") + + # iterations are not the same as episodes. 
+ for i in range(num_iterations): + result = algo.train() + # # Save every 10 iterations or after last iteration in training + # if (i % 100 == 0) or (i == num_iterations - 1): + print( + f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}") + # save checkpoint file + checkpoint_file = algo.save("./") + print(f"Checkpoint saved at {checkpoint_file}") + + # convert num_iterations to num_episodes + num_episodes = len( + result["hist_stats"]["episode_lengths"]) * num_iterations + # convert num_iterations to num_timesteps + num_timesteps = sum( + result["hist_stats"]["episode_lengths"] * num_iterations) + # calculate number of wins + + # train time + print(f"Training took {time.time() - start_time:.2f} seconds") + print( + f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}") + return result + + +def load_model_from_checkpoint(config, checkpoint=None): + # create an empty Algorithm + algo = config.build() + + if checkpoint is None: + # Get the checkpoint with the highest iteration number + checkpoint = get_most_recent_checkpoint(config) + + # restore the agent from the checkpoint + algo.restore(checkpoint) + + return algo + + +def get_most_recent_checkpoint(config): + """ + Get the most recent checkpoint for specified action type, red agent and algorithm + """ + + env_config = list(config.env_config.values())[0] + action_type = env_config.action_type + red_agent = env_config.red_agent_identifier + algo_name = config.algo_class.__name__ + + # Gets the latest checkpoint (highest iteration not datetime) to use as the final trained model + relevant_checkpoints = glob.glob( + f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*" + ) + checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints] + max_checkpoint = str(max(checkpoint_numbers)) + checkpoint_number_to_use = "0" * (6 - len(max_checkpoint)) + max_checkpoint + checkpoint = ( + relevant_checkpoints[0].split("_")[0] + + "_" + + checkpoint_number_to_use + + 
"/rllib_checkpoint.json" + ) + + return checkpoint diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py new file mode 100644 index 00000000..cb12210c --- /dev/null +++ b/src/primaite/agents/sb3.py @@ -0,0 +1,28 @@ +# from typing import Optional +# +# from primaite.agents.agent_abc import AgentABC +# from primaite.environment.primaite_env import Primaite +# +# +# class SB3PPO(AgentABC): +# def __init__(self, env: Primaite): +# super().__init__(env) +# +# def _setup(self): +# if self._env.training_config +# pass +# +# def learn(self, time_steps: Optional[int], episodes: Optional[int]): +# pass +# +# def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): +# pass +# +# def load(self): +# pass +# +# def save(self): +# pass +# +# def export(self): +# pass \ No newline at end of file diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 68ad80f2..121beb60 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -79,6 +79,33 @@ class Protocol(Enum): NONE = 7 +class SessionType(Enum): + "The type of PrimAITE Session to be run." + TRAINING = 1 + EVALUATION = 2 + BOTH = 3 + + +class VerboseLevel(Enum): + """PrimAITE Session Output verbose level.""" + NO_OUTPUT = 0 + INFO = 1 + DEBUG = 2 + + +class AgentFramework(Enum): + NONE = 0 + SB3 = 1 + RLLIB = 2 + + +class RedAgentIdentifier(Enum): + A2C = 1 + PPO = 2 + HARDCODED = 3 + RANDOM = 4 + + class ActionType(Enum): """Action type enumeration.""" diff --git a/src/primaite/common/training_config.py b/src/primaite/common/training_config.py deleted file mode 100644 index d45bedf9..00000000 --- a/src/primaite/common/training_config.py +++ /dev/null @@ -1,95 +0,0 @@ -# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
-# """The config class.""" -# from dataclasses import dataclass -# -# from primaite.common.enums import ActionType -# -# -# @dataclass() -# class TrainingConfig: -# """Class to hold main config values.""" -# -# # Generic -# agent_identifier: str # The Red Agent algo/class to be used -# action_type: ActionType # type of action to use (NODE/ACL/ANY) -# num_episodes: int # number of episodes to train over -# num_steps: int # number of steps in an episode -# time_delay: int # delay between steps (ms) - applies to generic agents only -# # file -# session_type: str # the session type to run (TRAINING or EVALUATION) -# load_agent: str # Determine whether to load an agent from file -# agent_load_file: str # File path and file name of agent if you're loading one in -# -# # Environment -# observation_space_high_value: int # The high value for the observation space -# -# # Reward values -# # Generic -# all_ok: int -# # Node Hardware State -# off_should_be_on: int -# off_should_be_resetting: int -# on_should_be_off: int -# on_should_be_resetting: int -# resetting_should_be_on: int -# resetting_should_be_off: int -# resetting: int -# # Node Software or Service State -# good_should_be_patching: int -# good_should_be_compromised: int -# good_should_be_overwhelmed: int -# patching_should_be_good: int -# patching_should_be_compromised: int -# patching_should_be_overwhelmed: int -# patching: int -# compromised_should_be_good: int -# compromised_should_be_patching: int -# compromised_should_be_overwhelmed: int -# compromised: int -# overwhelmed_should_be_good: int -# overwhelmed_should_be_patching: int -# overwhelmed_should_be_compromised: int -# overwhelmed: int -# # Node File System State -# good_should_be_repairing: int -# good_should_be_restoring: int -# good_should_be_corrupt: int -# good_should_be_destroyed: int -# repairing_should_be_good: int -# repairing_should_be_restoring: int -# repairing_should_be_corrupt: int -# repairing_should_be_destroyed: int # Repairing does not 
fix destroyed state - you need to restore -# -# repairing: int -# restoring_should_be_good: int -# restoring_should_be_repairing: int -# restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption) -# -# restoring_should_be_destroyed: int -# restoring: int -# corrupt_should_be_good: int -# corrupt_should_be_repairing: int -# corrupt_should_be_restoring: int -# corrupt_should_be_destroyed: int -# corrupt: int -# destroyed_should_be_good: int -# destroyed_should_be_repairing: int -# destroyed_should_be_restoring: int -# destroyed_should_be_corrupt: int -# destroyed: int -# scanning: int -# # IER status -# red_ier_running: int -# green_ier_blocked: int -# -# # Patching / Reset -# os_patching_duration: int # The time taken to patch the OS -# node_reset_duration: int # The time taken to reset a node (hardware) -# node_booting_duration = 0 # The Time taken to turn on the node -# node_shutdown_duration = 0 # The time taken to turn off the node -# service_patching_duration: int # The time taken to patch a service -# file_system_repairing_limit: int # The time take to repair a file -# file_system_restoring_limit: int # The time take to restore a file -# file_system_scanning_limit: int # The time taken to scan the file system -# # Patching / Reset -# diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 46389297..4fd2142e 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -2,6 +2,8 @@ from pathlib import Path from typing import Final +import networkx + from primaite import USERS_CONFIG_DIR, getLogger _LOGGER = getLogger(__name__) @@ -9,6 +11,12 @@ _LOGGER = getLogger(__name__) _EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +# class LayDownConfig: +# network: networkx.Graph +# POL +# EIR +# ACL + def ddos_basic_one_config_path() -> Path: """ The path to the example lay_down_config_1_DDOS_basic.yaml file. 
diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..b0956d42 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,4 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +from __future__ import annotations + from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Final, Union, Optional @@ -6,7 +8,8 @@ from typing import Any, Dict, Final, Union, Optional import yaml from primaite import USERS_CONFIG_DIR, getLogger -from primaite.common.enums import ActionType +from primaite.common.enums import ActionType, RedAgentIdentifier, \ + AgentFramework, SessionType _LOGGER = getLogger(__name__) @@ -16,10 +19,11 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training @dataclass() class TrainingConfig: """The Training Config class.""" + agent_framework: AgentFramework = AgentFramework.SB3 + "The agent framework." - # Generic - agent_identifier: str = "STABLE_BASELINES3_A2C" - "The Red Agent algo/class to be used." + red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO + "The red agent/algo class." action_type: ActionType = ActionType.ANY "The ActionType to use." @@ -38,8 +42,8 @@ class TrainingConfig: "The delay between steps (ms). Applies to generic agents only." # file - session_type: str = "TRAINING" - "the session type to run (TRAINING or EVALUATION)" + session_type: SessionType = SessionType.TRAINING + "The type of PrimAITE session to run." load_agent: str = False "Determine whether to load an agent from file." @@ -137,6 +141,24 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system." 
+ @classmethod + def from_dict( + cls, + config_dict: Dict[str, Union[str, int, bool]] + ) -> TrainingConfig: + field_enum_map = { + "agent_framework": AgentFramework, + "red_agent_identifier": RedAgentIdentifier, + "action_type": ActionType, + "session_type": SessionType + } + + for field, enum_class in field_enum_map.items(): + if field in config_dict: + config_dict[field] = enum_class[field] + + return TrainingConfig(**config_dict) + def to_dict(self, json_serializable: bool = True): """ Serialise the ``TrainingConfig`` as dict. @@ -196,10 +218,8 @@ def load(file_path: Union[str, Path], f"from legacy format. Attempting to use file as is." ) _LOGGER.error(msg) - # Convert values to Enums - config["action_type"] = ActionType[config["action_type"]] try: - return TrainingConfig(**config) + return TrainingConfig.from_dict(**config) except TypeError as e: msg = ( f"Error when creating an instance of {TrainingConfig} " @@ -214,22 +234,30 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + agent_framework: AgentFramework = AgentFramework.SB3, + red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO, + action_type: ActionType = ActionType.ANY, + num_steps: int = 256 ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. :param legacy_config_dict: A legacy training config dict. - :param num_steps: The number of steps to set as legacy training configs - don't have num_steps values. + :param agent_framework: The agent framework to use as legacy training + configs don't have agent_framework values. + :param red_agent_identifier: The red agent identifier to use as legacy + training configs don't have red_agent_identifier values. :param action_type: The action space type to set as legacy training configs don't have action_type values. 
+ :param num_steps: The number of steps to set as legacy training configs + don't have num_steps values. :return: The converted training config dict. """ config_dict = { - "num_steps": num_steps, - "action_type": action_type + "agent_framework": agent_framework.name, + "red_agent_identifier": red_agent_identifier.name, + "action_type": action_type.name, + "num_steps": num_steps } for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) @@ -246,7 +274,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str: :return: The mapped key. """ key_mapping = { - "agentIdentifier": "agent_identifier", + "agentIdentifier": None, "numEpisodes": "num_episodes", "timeDelay": "time_delay", "configFilename": None, diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py new file mode 100644 index 00000000..3957e822 --- /dev/null +++ b/src/primaite/primaite_session.py @@ -0,0 +1,216 @@ +from __future__ import annotations + +import json +from datetime import datetime +from pathlib import Path +from typing import Final, Optional, Union +from uuid import uuid4 + +from primaite import getLogger, SESSIONS_DIR +from primaite.config.training_config import TrainingConfig +from primaite.environment.primaite_env import Primaite + +_LOGGER = getLogger(__name__) + + +def _get_session_path(session_timestamp: datetime) -> Path: + """ + Get the directory path the session will output to. + + This is set in the format of: + ~/primaite/sessions//_. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. 
+ """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = SESSIONS_DIR / date_dir / session_dir + session_path.mkdir(exist_ok=True, parents=True) + _LOGGER.debug(f"Created PrimAITE Session path: {session_path}") + + return session_path + + +class PrimaiteSession: + + def __init__( + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + auto: bool = True + ): + if not isinstance(training_config_path, Path): + training_config_path = Path(training_config_path) + self._training_config_path: Final[Union[Path]] = training_config_path + + if not isinstance(lay_down_config_path, Path): + lay_down_config_path = Path(lay_down_config_path) + self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path + + self._auto: Final[bool] = auto + + self._uuid: str = str(uuid4()) + self._session_timestamp: Final[datetime] = datetime.now() + self._session_path: Final[Path] = _get_session_path( + self._session_timestamp + ) + self._timestamp_str: Final[str] = self._session_timestamp.strftime( + "%Y-%m-%d_%H-%M-%S") + self._metadata_path = self._session_path / "session_metadata.json" + + + self._env = None + self._training_config = None + self._can_learn: bool = False + _LOGGER.debug("") + + if self._auto: + self.setup() + self.learn() + + @property + def uuid(self): + """The session UUID.""" + return self._uuid + + def _setup_primaite_env(self, transaction_list: Optional[list] = None): + if not transaction_list: + transaction_list = [] + self._env: Primaite = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + transaction_list=transaction_list, + session_path=self._session_path, + timestamp_str=self._timestamp_str + ) + self._training_config: TrainingConfig = self._env.training_config + + def _write_session_metadata_file(self): + """ + Write the ``session_metadata.json`` file. 
+ + Creates a ``session_metadata.json`` in the ``session_dir`` directory + and adds the following key/value pairs: + + - uuid: The UUID assigned to the session upon instantiation. + - start_datetime: The date & time the session started in iso format. + - end_datetime: NULL. + - total_episodes: NULL. + - total_time_steps: NULL. + - env: + - training_config: + - All training config items + - lay_down_config: + - All lay down config items + """ + metadata_dict = { + "uuid": self._uuid, + "start_datetime": self._session_timestamp.isoformat(), + "end_datetime": None, + "total_episodes": None, + "total_time_steps": None, + "env": { + "training_config": self._env.training_config.to_dict( + json_serializable=True + ), + "lay_down_config": self._env.lay_down_config, + }, + } + _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}") + with open(self._metadata_path, "w") as file: + json.dump(metadata_dict, file) + + def _update_session_metadata_file(self): + """ + Update the ``session_metadata.json`` file. + + Updates the `session_metadata.json`` in the ``session_dir`` directory + with the following key/value pairs: + + - end_datetime: NULL. + - total_episodes: NULL. + - total_time_steps: NULL. 
+ """ + with open(self._metadata_path, "r") as file: + metadata_dict = json.load(file) + + metadata_dict["end_datetime"] = datetime.now().isoformat() + metadata_dict["total_episodes"] = self._env.episode_count + metadata_dict["total_time_steps"] = self._env.total_step_count + + _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}") + with open(self._metadata_path, "w") as file: + json.dump(metadata_dict, file) + + def setup(self): + self._setup_primaite_env() + self._can_learn = True + pass + + def learn( + self, + time_steps: Optional[int], + episodes: Optional[int], + iterations: Optional[int], + **kwargs + ): + if self._can_learn: + # Run environment against an agent + if self._training_config.agent_identifier == "GENERIC": + run_generic(env=env, config_values=config_values) + elif self._training_config == "STABLE_BASELINES3_PPO": + run_stable_baselines3_ppo( + env=env, + config_values=config_values, + session_path=session_dir, + timestamp_str=timestamp_str, + ) + elif self._training_config == "STABLE_BASELINES3_A2C": + run_stable_baselines3_a2c( + env=env, + config_values=config_values, + session_path=session_dir, + timestamp_str=timestamp_str, + ) + + print("Session finished") + _LOGGER.debug("Session finished") + + print("Saving transaction logs...") + write_transaction_to_file( + transaction_list=transaction_list, + session_path=session_dir, + timestamp_str=timestamp_str, + ) + + print("Updating Session Metadata file...") + _update_session_metadata_file(session_dir=session_dir, env=env) + + print("Finished") + _LOGGER.debug("Finished") + + def evaluate( + self, + time_steps: Optional[int], + episodes: Optional[int], + **kwargs + ): + pass + + def export(self): + pass + + @classmethod + def import_agent( + cls, + gent_path: str, + training_config_path: str, + lay_down_config_path: str + ) -> PrimaiteSession: + session = PrimaiteSession(training_config_path, lay_down_config_path) + + # Reset the UUID + session._uuid = "" + + return session 
diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy/new_training_config.yaml index becc1799..44897bfa 100644 --- a/tests/config/legacy/new_training_config.yaml +++ b/tests/config/legacy/new_training_config.yaml @@ -1,11 +1,20 @@ # Main Config File # Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C + +# Sets which agent algorithm framework will be used: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray[RLlib]) +# "NONE" (Custom Agent) +agent_framework: RLLIB + +# Sets which Red Agent algo/class will be used: +# "PPO" (Proximal Policy Optimization) +# "A2C" (Advantage Actor Critic) +# "HARDCODED" (Custom Agent) +# "RANDOM" (Random Action) +red_agent_identifier: PPO + # Sets How the Action Space is defined: # "NODE" # "ACL" From 68499392655a8fb8538748a7eff4903bd817e572 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 15 Jun 2023 09:48:44 +0100 Subject: [PATCH 02/43] #917 - started working on the Agent abstract classes and sub-classes --- docs/source/config.rst | 7 --- src/primaite/agents/agent_abc.py | 41 +++++++++++++++- src/primaite/agents/sb3.py | 63 +++++++++++++----------- src/primaite/primaite_session.py | 84 +++++++++++++++++++++++++------- 4 files changed, 141 insertions(+), 54 deletions(-) diff --git a/docs/source/config.rst b/docs/source/config.rst index 81468f17..1bea0671 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -20,13 +20,6 @@ The environment config file consists of the following attributes: **Generic Config Values** -* **agent_identifier** [enum] - - This identifies the agent to use for the session. 
Select from one of the following: - - * GENERIC - Where a user developed agent is to be used - * STABLE_BASELINES3_PPO - Use a SB3 PPO agent - * STABLE_BASELINES3_A2C - use a SB3 A2C agent * **agent_framework** [enum] diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index c500128d..c9067210 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -1,14 +1,20 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Optional, Final, Dict, Any +from primaite import getLogger +from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite +_LOGGER = getLogger(__name__) + class AgentABC(ABC): @abstractmethod def __init__(self, env: Primaite): self._env: Primaite = env + self._training_config: Final[TrainingConfig] = self._env.training_config + self._lay_down_config: Dict[str, Any] = self._env.lay_down_config self._agent = None @abstractmethod @@ -33,4 +39,35 @@ class AgentABC(ABC): @abstractmethod def export(self): - pass \ No newline at end of file + pass + + +class DeterministicAgentABC(AgentABC): + @abstractmethod + def __init__(self, env: Primaite): + self._env: Primaite = env + self._agent = None + + @abstractmethod + def _setup(self): + pass + + def learn(self, time_steps: Optional[int], episodes: Optional[int]): + pass + _LOGGER.warning("Deterministic agents cannot learn") + + @abstractmethod + def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + @abstractmethod + def load(self): + pass + + @abstractmethod + def save(self): + pass + + @abstractmethod + def export(self): + pass diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index cb12210c..7d0fba3b 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,28 +1,35 @@ -# from typing import Optional -# -# from primaite.agents.agent_abc import AgentABC -# from primaite.environment.primaite_env import 
Primaite -# -# -# class SB3PPO(AgentABC): -# def __init__(self, env: Primaite): -# super().__init__(env) -# -# def _setup(self): -# if self._env.training_config -# pass -# -# def learn(self, time_steps: Optional[int], episodes: Optional[int]): -# pass -# -# def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): -# pass -# -# def load(self): -# pass -# -# def save(self): -# pass -# -# def export(self): -# pass \ No newline at end of file +from typing import Optional + +from stable_baselines3 import PPO + +from primaite.agents.agent_abc import AgentABC +from primaite.environment.primaite_env import Primaite +from stable_baselines3.ppo import MlpPolicy as PPOMlp + +class SB3PPO(AgentABC): + def __init__(self, env: Primaite): + super().__init__(env) + + def _setup(self): + self._agent = PPO( + PPOMlp, + self._env, + verbose=0, + n_steps=self._training_config.num_steps + ) + + + def learn(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): + pass + + def load(self): + pass + + def save(self): + pass + + def export(self): + pass \ No newline at end of file diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 3957e822..0efc0acf 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -7,6 +7,8 @@ from typing import Final, Optional, Union from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR +from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ + ActionType from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite @@ -61,7 +63,7 @@ class PrimaiteSession: self._env = None - self._training_config = None + self._training_config: TrainingConfig self._can_learn: bool = False _LOGGER.debug("") @@ -157,22 +159,70 @@ class PrimaiteSession: ): if self._can_learn: # Run environment against an agent - if 
self._training_config.agent_identifier == "GENERIC": - run_generic(env=env, config_values=config_values) - elif self._training_config == "STABLE_BASELINES3_PPO": - run_stable_baselines3_ppo( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - elif self._training_config == "STABLE_BASELINES3_A2C": - run_stable_baselines3_a2c( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) + if self._training_config.agent_framework == AgentFramework.NONE: + if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: + # Stochastic Random Agent + run_generic(env=env, config_values=config_values) + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: + if self._training_config.action_type == ActionType.NODE: + # Deterministic Hardcoded Agent with Node Action Space + pass + + elif self._training_config.action_type == ActionType.ACL: + # Deterministic Hardcoded Agent with ACL Action Space + pass + + elif self._training_config.action_type == ActionType.ANY: + # Deterministic Hardcoded Agent with ANY Action Space + pass + + else: + # Invalid RedAgentIdentifier ActionType combo + pass + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif self._training_config.agent_framework == AgentFramework.SB3: + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + # Stable Baselines3/Proximal Policy Optimization + run_stable_baselines3_ppo( + env=env, + config_values=config_values, + session_path=session_dir, + timestamp_str=timestamp_str, + ) + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Stable Baselines3/Advantage Actor Critic + run_stable_baselines3_a2c( + env=env, + config_values=config_values, + session_path=session_dir, + timestamp_str=timestamp_str, + ) + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif 
self._training_config.agent_framework == AgentFramework.RLLIB: + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + # Ray RLlib/Proximal Policy Optimization + pass + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Ray RLlib/Advantage Actor Critic + pass + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + else: + # Invalid AgentFramework + pass print("Session finished") _LOGGER.debug("Session finished") From c2c396052f8e46350210a6ba542dc140d696995a Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Sun, 18 Jun 2023 22:40:56 +0100 Subject: [PATCH 03/43] #917 - Got RLlib fully training in PrimAITE. Started integrating the the other agents into the Session class --- src/primaite/VERSION | 2 +- src/primaite/agents/agent_abc.py | 83 ++++++- src/primaite/agents/rllib.py | 228 +++++++----------- src/primaite/agents/sb3.py | 64 ++++- src/primaite/common/enums.py | 17 ++ .../training/training_config_main.yaml | 38 ++- src/primaite/config/training_config.py | 17 +- src/primaite/environment/primaite_env.py | 6 +- src/primaite/main.py | 48 ---- 9 files changed, 274 insertions(+), 229 deletions(-) diff --git a/src/primaite/VERSION b/src/primaite/VERSION index bd82b28c..4111d137 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0dev0 +2.0.0rc1 diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py index c9067210..d5aceeaf 100644 --- a/src/primaite/agents/agent_abc.py +++ b/src/primaite/agents/agent_abc.py @@ -1,36 +1,84 @@ from abc import ABC, abstractmethod -from typing import Optional, Final, Dict, Any +from datetime import datetime +from pathlib import Path +from typing import Optional, Final, Dict, Any, Union, Tuple + +import yaml from primaite import getLogger -from primaite.config.training_config import TrainingConfig +from primaite.config.training_config import TrainingConfig, load from primaite.environment.primaite_env import Primaite _LOGGER 
= getLogger(__name__) +def _get_temp_session_path(session_timestamp: datetime) -> Path: + """ + Get a temp directory session path the test session will output to. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. + """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = Path("./") / date_dir / session_dir + session_path.mkdir(exist_ok=True, parents=True) + + return session_path + + class AgentABC(ABC): @abstractmethod - def __init__(self, env: Primaite): - self._env: Primaite = env - self._training_config: Final[TrainingConfig] = self._env.training_config - self._lay_down_config: Dict[str, Any] = self._env.lay_down_config + def __init__( + self, + training_config_path, + lay_down_config_path + ): + self._training_config_path = training_config_path + self._training_config: Final[TrainingConfig] = load( + self._training_config_path + ) + self._lay_down_config_path = lay_down_config_path + self._env: Primaite self._agent = None + self.session_timestamp: datetime = datetime.now() + self.session_path = _get_temp_session_path(self.session_timestamp) + + self.timestamp_str = self.session_timestamp.strftime( + "%Y-%m-%d_%H-%M-%S") @abstractmethod def _setup(self): pass @abstractmethod - def learn(self, time_steps: Optional[int], episodes: Optional[int]): + def _save_checkpoint(self): pass @abstractmethod - def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): pass @abstractmethod - def load(self): + def evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + pass + + @abstractmethod + def _get_latest_checkpoint(self): + pass + + @classmethod + @abstractmethod + def load(cls): pass @abstractmethod @@ -44,14 +92,24 @@ class AgentABC(ABC): class DeterministicAgentABC(AgentABC): 
@abstractmethod - def __init__(self, env: Primaite): - self._env: Primaite = env + def __init__( + self, + training_config_path, + lay_down_config_path + ): + self._training_config_path = training_config_path + self._lay_down_config_path = lay_down_config_path + self._env: Primaite self._agent = None @abstractmethod def _setup(self): pass + @abstractmethod + def _get_latest_checkpoint(self): + pass + def learn(self, time_steps: Optional[int], episodes: Optional[int]): pass _LOGGER.warning("Deterministic agents cannot learn") @@ -60,8 +118,9 @@ class DeterministicAgentABC(AgentABC): def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): pass + @classmethod @abstractmethod - def load(self): + def load(cls): pass @abstractmethod diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index d07265b4..bb0daefb 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -8,170 +8,106 @@ from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.registry import register_env +from primaite.agents.agent_abc import AgentABC from primaite.config import training_config from primaite.environment.primaite_env import Primaite -class DLFramework(Enum): - """The DL Frameworks enumeration.""" - TF = "tf" - TF2 = "tf2" - TORCH = "torch" +def _env_creator(env_config): + return Primaite( + training_config_path=env_config["training_config_path"], + lay_down_config_path=env_config["lay_down_config_path"], + transaction_list=env_config["transaction_list"], + session_path=env_config["session_path"], + timestamp_str=env_config["timestamp_str"] + ) -def env_creator(env_config): - training_config_path = env_config["training_config_path"] - lay_down_config_path = env_config["lay_down_config_path"] - return Primaite(training_config_path, lay_down_config_path, []) +class RLlibPPO(AgentABC): + def __init__( + self, + training_config_path, + lay_down_config_path + ): + 
super().__init__(training_config_path, lay_down_config_path) + self._ppo_config: PPOConfig + self._current_result: dict + self._setup() -def get_ppo_config( - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - framework: Optional[DLFramework] = DLFramework.TORCH -) -> PPOConfig(): - # Register environment - register_env("primaite", env_creator) + def _setup(self): + register_env("primaite", _env_creator) + self._ppo_config = PPOConfig() - # Setup PPO - config = PPOConfig() - - config_values = training_config.load(training_config_path) - - # Setup our config object to use our environment - config.environment( - env="primaite", - env_config=dict( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path + self._ppo_config.environment( + env="primaite", + env_config=dict( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + transaction_list=[], + session_path=self.session_path, + timestamp_str=self.timestamp_str + ) ) - ) - env_config = config_values - action_type = env_config.action_type - red_agent = env_config.red_agent_identifier + self._ppo_config.training( + train_batch_size=self._training_config.num_steps + ) + self._ppo_config.framework( + framework=self._training_config.deep_learning_framework.value + ) - if red_agent == "RANDOM" and action_type == "NODE": - config.training( - train_batch_size=6000, lr=5e-5 - ) # number of steps in a training iteration - elif red_agent == "RANDOM" and action_type != "NODE": - config.training(train_batch_size=6000, lr=5e-5) - elif red_agent == "CONFIG" and action_type == "NODE": - config.training(train_batch_size=400, lr=5e-5) - elif red_agent == "CONFIG" and action_type != "NONE": - config.training(train_batch_size=500, lr=5e-5) - else: - config.training(train_batch_size=500, lr=5e-5) + self._ppo_config.rollouts( + num_rollout_workers=1, + num_envs_per_worker=1, + horizon=self._training_config.num_steps + ) + 
self._agent: Algorithm = self._ppo_config.build() - # Decide if you want torch or tensorflow DL framework. Default is "tf" - config.framework(framework=framework.value) + def _save_checkpoint(self): + checkpoint_n = self._training_config.checkpoint_every_n_episodes + episode_count = self._current_result["episodes_total"] + if checkpoint_n > 0 and episode_count > 0: + if ( + (episode_count % checkpoint_n == 0) + or (episode_count == self._training_config.num_episodes) + ): + self._agent.save(self.session_path) - # Set the log level to DEBUG, INFO, WARN, or ERROR - config.debugging(seed=415, log_level="ERROR") + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + # Temporarily override train_batch_size and horizon + if time_steps: + self._ppo_config.train_batch_size = time_steps + self._ppo_config.horizon = time_steps - # Setup evaluation - # Explicitly set "explore"=False to override default - # config.evaluation( - # evaluation_interval=100, - # evaluation_duration=20, - # # evaluation_duration_unit="timesteps",) #default episodes - # evaluation_config={"explore": False}, - # ) + if not episodes: + episodes = self._training_config.num_episodes - # Setup sampling rollout workers - config.rollouts( - num_rollout_workers=4, - num_envs_per_worker=1, - horizon=128, # num parralel workiers - ) # max num steps in an episode + for i in range(episodes): + self._current_result = self._agent.train() + self._save_checkpoint() + self._agent.stop() - config.build() # Build config + def evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + raise NotImplementedError - return config + def _get_latest_checkpoint(self): + raise NotImplementedError + @classmethod + def load(cls): + raise NotImplementedError -def train( - num_iterations: int, - config: Optional[PPOConfig] = None, - algo: Optional[Algorithm] = None -): - """ + def save(self): + raise NotImplementedError - Requires either the algorithm 
config (new model) or the algorithm itself (continue training from checkpoint) - """ - - start_time = time.time() - - if algo is None: - algo = config.build() - elif config is None: - config = algo.get_config() - - print(f"Algorithm type: {type(algo)}") - - # iterations are not the same as episodes. - for i in range(num_iterations): - result = algo.train() - # # Save every 10 iterations or after last iteration in training - # if (i % 100 == 0) or (i == num_iterations - 1): - print( - f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}") - # save checkpoint file - checkpoint_file = algo.save("./") - print(f"Checkpoint saved at {checkpoint_file}") - - # convert num_iterations to num_episodes - num_episodes = len( - result["hist_stats"]["episode_lengths"]) * num_iterations - # convert num_iterations to num_timesteps - num_timesteps = sum( - result["hist_stats"]["episode_lengths"] * num_iterations) - # calculate number of wins - - # train time - print(f"Training took {time.time() - start_time:.2f} seconds") - print( - f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}") - return result - - -def load_model_from_checkpoint(config, checkpoint=None): - # create an empty Algorithm - algo = config.build() - - if checkpoint is None: - # Get the checkpoint with the highest iteration number - checkpoint = get_most_recent_checkpoint(config) - - # restore the agent from the checkpoint - algo.restore(checkpoint) - - return algo - - -def get_most_recent_checkpoint(config): - """ - Get the most recent checkpoint for specified action type, red agent and algorithm - """ - - env_config = list(config.env_config.values())[0] - action_type = env_config.action_type - red_agent = env_config.red_agent_identifier - algo_name = config.algo_class.__name__ - - # Gets the latest checkpoint (highest iteration not datetime) to use as the final trained model - relevant_checkpoints = glob.glob( - f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*" - ) - 
checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints] - max_checkpoint = str(max(checkpoint_numbers)) - checkpoint_number_to_use = "0" * (6 - len(max_checkpoint)) + max_checkpoint - checkpoint = ( - relevant_checkpoints[0].split("_")[0] - + "_" - + checkpoint_number_to_use - + "/rllib_checkpoint.json" - ) - - return checkpoint + def export(self): + raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 7d0fba3b..8fbbd815 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -6,30 +6,74 @@ from primaite.agents.agent_abc import AgentABC from primaite.environment.primaite_env import Primaite from stable_baselines3.ppo import MlpPolicy as PPOMlp + class SB3PPO(AgentABC): - def __init__(self, env: Primaite): - super().__init__(env) + def __init__( + self, + training_config_path, + lay_down_config_path + ): + super().__init__(training_config_path, lay_down_config_path) + self._tensorboard_log_path = self.session_path / "tensorboard_logs" + self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) def _setup(self): + self._env = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + transaction_list=[], + session_path=self.session_path, + timestamp_str=self.timestamp_str + ) self._agent = PPO( PPOMlp, self._env, verbose=0, - n_steps=self._training_config.num_steps + n_steps=self._training_config.num_steps, + tensorboard_log=self._tensorboard_log_path ) + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + if not time_steps: + time_steps = self._training_config.num_steps - def learn(self, time_steps: Optional[int], episodes: Optional[int]): - pass + if not episodes: + episodes = self._training_config.num_episodes - def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): - pass + for i in range(episodes): + self._agent.learn(total_timesteps=time_steps) + + def 
evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + deterministic: bool = True + ): + if not time_steps: + time_steps = self._training_config.num_steps + + if not episodes: + episodes = self._training_config.num_episodes + + for episode in range(episodes): + obs = self._env.reset() + + for step in range(time_steps): + action, _states = self._agent.predict( + obs, + deterministic=deterministic + ) + obs, rewards, done, info = self._env.step(action) def load(self): - pass + raise NotImplementedError def save(self): - pass + raise NotImplementedError def export(self): - pass \ No newline at end of file + raise NotImplementedError diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 121beb60..f28916c2 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -95,15 +95,32 @@ class VerboseLevel(Enum): class AgentFramework(Enum): NONE = 0 + "Custom Agent" SB3 = 1 + "Stable Baselines3" RLLIB = 2 + "Ray RLlib" + + +class DeepLearningFramework(Enum): + """The deep learning framework enumeration.""" + TF = "tf" + "Tensorflow" + TF2 = "tf2" + "Tensorflow 2.x" + TORCH = "torch" + "PyTorch" class RedAgentIdentifier(Enum): A2C = 1 + "Advantage Actor Critic" PPO = 2 + "Proximal Policy Optimization" HARDCODED = 3 + "Custom Agent" RANDOM = 4 + "Custom Agent" class ActionType(Enum): diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..ebee7f77 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -1,26 +1,52 @@ # Main Config File -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +# Sets which agent algorithm framework will be used: +# Options 
are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "NONE" (Custom Agent) +agent_framework: RLLIB + +# Sets which deep learning framework will be used. Default is TF (Tensorflow). +# Options are: +# "TF" (Tensorflow) +# TF2 (Tensorflow 2.X) +# TORCH (PyTorch) +deep_learning_framework: TORCH + +# Sets which Red Agent algo/class will be used: +# Options are: +# "A2C" (Advantage Actor Critic) +# "PPO" (Proximal Policy Optimization) +# "HARDCODED" (Custom Agent) +# "RANDOM" (Random Action) +red_agent_identifier: PPO + # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions action_type: NODE + # Number of episodes to run per session num_episodes: 10 + # Number of time_steps per episode num_steps: 256 + +# Sets how often the agent will save a checkpoint (every n time episodes). +# Set to 0 if no checkpoints are required. Default is 10 +checkpoint_every_n_episodes: 5 + # Time delay between steps (for generic agents) time_delay: 10 + # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING + # Determine whether to load an agent from file load_agent: False + # File path and file name of agent if you're loading one in agent_load_file: C:\[Path]\[agent_saved_filename.zip] diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index b0956d42..c2cb8db9 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -8,6 +8,7 @@ from typing import Any, Dict, Final, Union, Optional import yaml from primaite import USERS_CONFIG_DIR, getLogger +from primaite.common.enums import DeepLearningFramework from primaite.common.enums import ActionType, RedAgentIdentifier, \ AgentFramework, SessionType @@ -20,10 +21,13 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training class TrainingConfig: """The Training Config class.""" agent_framework: AgentFramework = AgentFramework.SB3 - "The agent framework." 
+ "The AgentFramework" + + deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF + "The DeepLearningFramework." red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO - "The red agent/algo class." + "The RedAgentIdentifier.." action_type: ActionType = ActionType.ANY "The ActionType to use." @@ -33,6 +37,10 @@ class TrainingConfig: num_steps: int = 256 "The number of steps in an episode." + + checkpoint_every_n_episodes: int = 5 + "The agent will save a checkpoint every n episodes." + observation_space: dict = field( default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]} ) @@ -148,6 +156,7 @@ class TrainingConfig: ) -> TrainingConfig: field_enum_map = { "agent_framework": AgentFramework, + "deep_learning_framework": DeepLearningFramework, "red_agent_identifier": RedAgentIdentifier, "action_type": ActionType, "session_type": SessionType @@ -155,7 +164,7 @@ class TrainingConfig: for field, enum_class in field_enum_map.items(): if field in config_dict: - config_dict[field] = enum_class[field] + config_dict[field] = enum_class[config_dict[field]] return TrainingConfig(**config_dict) @@ -219,7 +228,7 @@ def load(file_path: Union[str, Path], ) _LOGGER.error(msg) try: - return TrainingConfig.from_dict(**config) + return TrainingConfig.from_dict(config) except TypeError as e: msg = ( f"Error when creating an instance of {TrainingConfig} " diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..e0cfb119 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -5,7 +5,7 @@ import csv import logging from datetime import datetime from pathlib import Path -from typing import Dict, Tuple, Union +from typing import Dict, Tuple, Union, Final import networkx as nx import numpy as np @@ -77,6 +77,8 @@ class Primaite(Env): :param timestamp_str: The session timestamp in the format: _. 
""" + self.session_path: Final[Path] = session_path + self.timestamp_str: Final[str] = timestamp_str self._training_config_path = training_config_path self._lay_down_config_path = lay_down_config_path @@ -93,7 +95,7 @@ class Primaite(Env): self.transaction_list = transaction_list # The agent in use - self.agent_identifier = self.training_config.agent_identifier + self.agent_identifier = self.training_config.red_agent_identifier # Create a dictionary to hold all the nodes self.nodes: Dict[str, NodeUnion] = {} diff --git a/src/primaite/main.py b/src/primaite/main.py index ac32a018..842b9259 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -108,54 +108,6 @@ def run_stable_baselines3_ppo( env.close() -def run_stable_baselines3_a2c( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -): - """ - Run against a stable_baselines3 A2C agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. 
- """ - if config_values.load_agent: - try: - agent = A2C.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) - - env.close() - - def _write_session_metadata_file( session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite ): From 23bafde4578f76f508aefa5da55d7ab29f7d6931 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 19 Jun 2023 20:27:08 +0100 Subject: [PATCH 04/43] #917 - Integrated both SB3 and RLlib agents into PrimaiteSession --- src/primaite/agents/agent.py | 251 ++++++++ src/primaite/agents/agent_abc.py | 132 ----- src/primaite/agents/rllib.py | 15 +- src/primaite/agents/sb3.py | 33 +- src/primaite/common/enums.py | 12 +- .../training/training_config_main.yaml | 7 + src/primaite/config/lay_down_config.py | 56 +- src/primaite/config/training_config.py | 103 ++-- src/primaite/environment/primaite_env.py | 1 - src/primaite/main.py | 534 ++++++++---------- src/primaite/primaite_session.py | 265 +++------ .../transactions/transactions_to_file.py | 1 + tests/conftest.py | 4 +- 13 files changed, 726 insertions(+), 688 deletions(-) create mode 100644 src/primaite/agents/agent.py 
delete mode 100644 src/primaite/agents/agent_abc.py diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py new file mode 100644 index 00000000..58158dcb --- /dev/null +++ b/src/primaite/agents/agent.py @@ -0,0 +1,251 @@ +import json +from abc import ABC, abstractmethod +from datetime import datetime +from pathlib import Path +from typing import Optional, Final, Dict, Union, List +from uuid import uuid4 + +from primaite import getLogger +from primaite.common.enums import OutputVerboseLevel +from primaite.config import lay_down_config +from primaite.config import training_config +from primaite.config.training_config import TrainingConfig +from primaite.environment.primaite_env import Primaite +from primaite.transactions.transactions_to_file import \ + write_transaction_to_file + +_LOGGER = getLogger(__name__) + + +def _get_temp_session_path(session_timestamp: datetime) -> Path: + """ + Get a temp directory session path the test session will output to. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. 
+ """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = Path("./") / date_dir / session_path + session_path.mkdir(exist_ok=True, parents=True) + + return session_path + + +class AgentSessionABC(ABC): + + @abstractmethod + def __init__( + self, + training_config_path, + lay_down_config_path + ): + if not isinstance(training_config_path, Path): + training_config_path = Path(training_config_path) + self._training_config_path: Final[Union[Path]] = training_config_path + self._training_config: Final[TrainingConfig] = training_config.load( + self._training_config_path + ) + + if not isinstance(lay_down_config_path, Path): + lay_down_config_path = Path(lay_down_config_path) + self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path + self._lay_down_config: Dict = lay_down_config.load( + self._lay_down_config_path + ) + self.output_verbose_level = self._training_config.output_verbose_level + + self._env: Primaite + self._agent = None + self._transaction_list: List[Dict] = [] + self._can_learn: bool = False + self._can_evaluate: bool = False + + self._uuid = str(uuid4()) + self.session_timestamp: datetime = datetime.now() + "The session timestamp" + self.session_path = _get_temp_session_path(self.session_timestamp) + "The Session path" + self.checkpoints_path = self.session_path / "checkpoints" + "The Session checkpoints path" + + self.timestamp_str = self.session_timestamp.strftime( + "%Y-%m-%d_%H-%M-%S") + "The session timestamp as a string" + + @property + def uuid(self): + """The Agent Session UUID.""" + return self._uuid + + def _write_session_metadata_file(self): + """ + Write the ``session_metadata.json`` file. + + Creates a ``session_metadata.json`` in the ``session_path`` directory + and adds the following key/value pairs: + + - uuid: The UUID assigned to the session upon instantiation. + - start_datetime: The date & time the session started in iso format. 
+ - end_datetime: NULL. + - total_episodes: NULL. + - total_time_steps: NULL. + - env: + - training_config: + - All training config items + - lay_down_config: + - All lay down config items + + """ + metadata_dict = { + "uuid": self.uuid, + "start_datetime": self.session_timestamp.isoformat(), + "end_datetime": None, + "total_episodes": None, + "total_time_steps": None, + "env": { + "training_config": self._training_config.to_dict( + json_serializable=True + ), + "lay_down_config": self._lay_down_config, + }, + } + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Writing Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished writing session metadata file") + + def _update_session_metadata_file(self): + """ + Update the ``session_metadata.json`` file. + + Updates the `session_metadata.json`` in the ``session_path`` directory + with the following key/value pairs: + + - end_datetime: The date & time the session ended in iso format. + - total_episodes: The total number of training episodes completed. + - total_time_steps: The total number of training time steps completed. 
+ """ + with open(self.session_path / "session_metadata.json", "r") as file: + metadata_dict = json.load(file) + + metadata_dict["end_datetime"] = datetime.now().isoformat() + metadata_dict["total_episodes"] = self._env.episode_count + metadata_dict["total_time_steps"] = self._env.total_step_count + + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Updating Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished updating session metadata file") + + @abstractmethod + def _setup(self): + if self.output_verbose_level >= OutputVerboseLevel.INFO: + _LOGGER.info( + "Welcome to the Primary-level AI Training Environment " + "(PrimAITE)" + ) + _LOGGER.debug( + f"The output directory for this agent is: {self.session_path}" + ) + self._write_session_metadata_file() + self._can_learn = True + self._can_evaluate = False + + @abstractmethod + def _save_checkpoint(self): + pass + + @abstractmethod + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs + ): + if self._can_learn: + _LOGGER.debug("Writing transactions") + write_transaction_to_file( + transaction_list=self._transaction_list, + session_path=self.session_path, + timestamp_str=self.timestamp_str, + ) + self._update_session_metadata_file() + self._can_evaluate = True + + @abstractmethod + def evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs + ): + pass + + @abstractmethod + def _get_latest_checkpoint(self): + pass + + @classmethod + @abstractmethod + def load(cls): + pass + + @abstractmethod + def save(self): + self._agent.save(self.session_path) + + @abstractmethod + def export(self): + pass + + +class DeterministicAgentSessionABC(AgentSessionABC): + @abstractmethod + def __init__( + self, + training_config_path, + lay_down_config_path + ): + self._training_config_path = training_config_path + self._lay_down_config_path = 
lay_down_config_path + self._env: Primaite + self._agent = None + + @abstractmethod + def _setup(self): + pass + + @abstractmethod + def _get_latest_checkpoint(self): + pass + + def learn( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + _LOGGER.warning("Deterministic agents cannot learn") + + @abstractmethod + def evaluate( + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None + ): + pass + + @classmethod + @abstractmethod + def load(cls): + pass + + @abstractmethod + def save(self): + pass + + @abstractmethod + def export(self): + pass diff --git a/src/primaite/agents/agent_abc.py b/src/primaite/agents/agent_abc.py deleted file mode 100644 index d5aceeaf..00000000 --- a/src/primaite/agents/agent_abc.py +++ /dev/null @@ -1,132 +0,0 @@ -from abc import ABC, abstractmethod -from datetime import datetime -from pathlib import Path -from typing import Optional, Final, Dict, Any, Union, Tuple - -import yaml - -from primaite import getLogger -from primaite.config.training_config import TrainingConfig, load -from primaite.environment.primaite_env import Primaite - -_LOGGER = getLogger(__name__) - - -def _get_temp_session_path(session_timestamp: datetime) -> Path: - """ - Get a temp directory session path the test session will output to. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. 
- """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path("./") / date_dir / session_dir - session_path.mkdir(exist_ok=True, parents=True) - - return session_path - - -class AgentABC(ABC): - - @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): - self._training_config_path = training_config_path - self._training_config: Final[TrainingConfig] = load( - self._training_config_path - ) - self._lay_down_config_path = lay_down_config_path - self._env: Primaite - self._agent = None - self.session_timestamp: datetime = datetime.now() - self.session_path = _get_temp_session_path(self.session_timestamp) - - self.timestamp_str = self.session_timestamp.strftime( - "%Y-%m-%d_%H-%M-%S") - - @abstractmethod - def _setup(self): - pass - - @abstractmethod - def _save_checkpoint(self): - pass - - @abstractmethod - def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None - ): - pass - - @abstractmethod - def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None - ): - pass - - @abstractmethod - def _get_latest_checkpoint(self): - pass - - @classmethod - @abstractmethod - def load(cls): - pass - - @abstractmethod - def save(self): - pass - - @abstractmethod - def export(self): - pass - - -class DeterministicAgentABC(AgentABC): - @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): - self._training_config_path = training_config_path - self._lay_down_config_path = lay_down_config_path - self._env: Primaite - self._agent = None - - @abstractmethod - def _setup(self): - pass - - @abstractmethod - def _get_latest_checkpoint(self): - pass - - def learn(self, time_steps: Optional[int], episodes: Optional[int]): - pass - _LOGGER.warning("Deterministic agents cannot learn") - - @abstractmethod - def evaluate(self, time_steps: Optional[int], episodes: Optional[int]): - 
pass - - @classmethod - @abstractmethod - def load(cls): - pass - - @abstractmethod - def save(self): - pass - - @abstractmethod - def export(self): - pass diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index bb0daefb..80318499 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -8,7 +8,7 @@ from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.registry import register_env -from primaite.agents.agent_abc import AgentABC +from primaite.agents.agent import AgentSessionABC from primaite.config import training_config from primaite.environment.primaite_env import Primaite @@ -23,7 +23,7 @@ def _env_creator(env_config): ) -class RLlibPPO(AgentABC): +class RLlibPPO(AgentSessionABC): def __init__( self, @@ -34,8 +34,10 @@ class RLlibPPO(AgentABC): self._ppo_config: PPOConfig self._current_result: dict self._setup() + self._agent.save() def _setup(self): + super()._setup() register_env("primaite", _env_creator) self._ppo_config = PPOConfig() @@ -72,12 +74,13 @@ class RLlibPPO(AgentABC): (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes) ): - self._agent.save(self.session_path) + self._agent.save(self.checkpoints_path) def learn( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): # Temporarily override train_batch_size and horizon if time_steps: @@ -91,11 +94,13 @@ class RLlibPPO(AgentABC): self._current_result = self._agent.train() self._save_checkpoint() self._agent.stop() + super().learn() def evaluate( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 8fbbd815..6e6d8a5d 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,13 +1,14 @@ from typing import Optional 
+import numpy as np from stable_baselines3 import PPO -from primaite.agents.agent_abc import AgentABC +from primaite.agents.agent import AgentSessionABC from primaite.environment.primaite_env import Primaite from stable_baselines3.ppo import MlpPolicy as PPOMlp -class SB3PPO(AgentABC): +class SB3PPO(AgentSessionABC): def __init__( self, training_config_path, @@ -16,8 +17,10 @@ class SB3PPO(AgentABC): super().__init__(training_config_path, lay_down_config_path) self._tensorboard_log_path = self.session_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) + self._setup() def _setup(self): + super()._setup() self._env = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, @@ -28,15 +31,30 @@ class SB3PPO(AgentABC): self._agent = PPO( PPOMlp, self._env, - verbose=0, + verbose=1, n_steps=self._training_config.num_steps, tensorboard_log=self._tensorboard_log_path ) + def _save_checkpoint(self): + checkpoint_n = self._training_config.checkpoint_every_n_episodes + episode_count = self._env.episode_count + if checkpoint_n > 0 and episode_count > 0: + if ( + (episode_count % checkpoint_n == 0) + or (episode_count == self._training_config.num_episodes) + ): + self._agent.save( + self.checkpoints_path / f"sb3ppo_{episode_count}.zip") + + def _get_latest_checkpoint(self): + pass + def learn( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): if not time_steps: time_steps = self._training_config.num_steps @@ -46,12 +64,15 @@ class SB3PPO(AgentABC): for i in range(episodes): self._agent.learn(total_timesteps=time_steps) + self._save_checkpoint() + super().learn() def evaluate( self, time_steps: Optional[int] = None, episodes: Optional[int] = None, - deterministic: bool = True + deterministic: bool = True, + **kwargs ): if not time_steps: time_steps = self._training_config.num_steps @@ -67,6 +88,8 @@ class SB3PPO(AgentABC): 
obs, deterministic=deterministic ) + if isinstance(action, np.ndarray): + action = np.int64(action) obs, rewards, done, info = self._env.step(action) def load(self): diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index f28916c2..0c787e87 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Enumerations for APE.""" -from enum import Enum +from enum import Enum, IntEnum class NodeType(Enum): @@ -172,3 +172,13 @@ class LinkStatus(Enum): MEDIUM = 2 HIGH = 3 OVERLOAD = 4 + + +class OutputVerboseLevel(IntEnum): + """The Agent output verbosity level.""" + NONE = 0 + "No Output" + INFO = 1 + "Info Messages" + ALL = 2 + "All Messages" diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index ebee7f77..703f37f5 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -54,6 +54,13 @@ agent_load_file: C:\[Path]\[agent_saved_filename.zip] # The high value for the observation space observation_space_high_value: 1000000000 +# The Agent output verbosity level: +# Options are: +# "NONE" (No Output) +# "INFO" (Info Messages) +# "ALL" (All Messages) +output_verbose_level: INFO + # Reward values # Generic all_ok: 0 diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 4fd2142e..49a33d6e 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,21 +1,63 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
from pathlib import Path -from typing import Final +from typing import Final, Union, Dict, Any import networkx +import yaml from primaite import USERS_CONFIG_DIR, getLogger _LOGGER = getLogger(__name__) -_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +_EXAMPLE_LAY_DOWN: Final[ + Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" -# class LayDownConfig: -# network: networkx.Graph -# POL -# EIR -# ACL +def convert_legacy_lay_down_config_dict( + legacy_config_dict: Dict[str, Any] +) -> Dict[str, Any]: + """ + Convert a legacy lay down config dict to the new format. + + :param legacy_config_dict: A legacy lay down config dict. + """ + _LOGGER.warning("Legacy lay down config conversion not yet implemented") + return legacy_config_dict + + +def load( + file_path: Union[str, Path], + legacy_file: bool = False +) -> Dict: + """ + Read in a lay down config yaml file. + + :param file_path: The config file path. + :param legacy_file: True if the config file is legacy format, otherwise + False. + :return: The lay down config as a dict. + :raises ValueError: If the file_path does not exist. + """ + if not isinstance(file_path, Path): + file_path = Path(file_path) + if file_path.exists(): + with open(file_path, "r") as file: + config = yaml.safe_load(file) + _LOGGER.debug(f"Loading lay down config file: {file_path}") + if legacy_file: + try: + config = convert_legacy_lay_down_config_dict(config) + except KeyError: + msg = ( + f"Failed to convert lay down config file {file_path} " + f"from legacy format. Attempting to use file as is." 
+ ) + _LOGGER.error(msg) + return config + msg = f"Cannot load the lay down config as it does not exist: {file_path}" + _LOGGER.error(msg) + raise ValueError(msg) + def ddos_basic_one_config_path() -> Path: """ diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index c2cb8db9..0d39f9c4 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -10,11 +10,27 @@ import yaml from primaite import USERS_CONFIG_DIR, getLogger from primaite.common.enums import DeepLearningFramework from primaite.common.enums import ActionType, RedAgentIdentifier, \ - AgentFramework, SessionType + AgentFramework, SessionType, OutputVerboseLevel _LOGGER = getLogger(__name__) -_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" +_EXAMPLE_TRAINING: Final[ + Path] = USERS_CONFIG_DIR / "example_config" / "training" + + +def main_training_config_path() -> Path: + """ + The path to the example training_config_main.yaml file. + + :return: The file path. + """ + path = _EXAMPLE_TRAINING / "training_config_main.yaml" + if not path.exists(): + msg = "Example config not found. Please run 'primaite setup'" + _LOGGER.critical(msg) + raise FileNotFoundError(msg) + + return path @dataclass() @@ -24,44 +40,47 @@ class TrainingConfig: "The AgentFramework" deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF - "The DeepLearningFramework." + "The DeepLearningFramework" red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO - "The RedAgentIdentifier.." + "The RedAgentIdentifier" action_type: ActionType = ActionType.ANY - "The ActionType to use." + "The ActionType to use" num_episodes: int = 10 - "The number of episodes to train over." + "The number of episodes to train over" num_steps: int = 256 - "The number of steps in an episode." + "The number of steps in an episode" checkpoint_every_n_episodes: int = 5 - "The agent will save a checkpoint every n episodes." 
+ "The agent will save a checkpoint every n episodes" observation_space: dict = field( default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]} ) - "The observation space config dict." + "The observation space config dict" time_delay: int = 10 - "The delay between steps (ms). Applies to generic agents only." + "The delay between steps (ms). Applies to generic agents only" # file session_type: SessionType = SessionType.TRAINING - "The type of PrimAITE session to run." + "The type of PrimAITE session to run" load_agent: str = False - "Determine whether to load an agent from file." + "Determine whether to load an agent from file" agent_load_file: Optional[str] = None - "File path and file name of agent if you're loading one in." + "File path and file name of agent if you're loading one in" # Environment observation_space_high_value: int = 1000000000 - "The high value for the observation space." + "The high value for the observation space" + + output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO + "The Agent output verbosity level" # Reward values # Generic @@ -126,28 +145,28 @@ class TrainingConfig: # Patching / Reset durations os_patching_duration: int = 5 - "The time taken to patch the OS." + "The time taken to patch the OS" node_reset_duration: int = 5 - "The time taken to reset a node (hardware)." + "The time taken to reset a node (hardware)" node_booting_duration: int = 3 - "The Time taken to turn on the node." + "The Time taken to turn on the node" node_shutdown_duration: int = 2 - "The time taken to turn off the node." + "The time taken to turn off the node" service_patching_duration: int = 5 - "The time taken to patch a service." + "The time taken to patch a service" file_system_repairing_limit: int = 5 - "The time take to repair the file system." + "The time take to repair the file system" file_system_restoring_limit: int = 5 - "The time take to restore the file system." 
+ "The time take to restore the file system" file_system_scanning_limit: int = 5 - "The time taken to scan the file system." + "The time taken to scan the file system" @classmethod def from_dict( @@ -157,9 +176,10 @@ class TrainingConfig: field_enum_map = { "agent_framework": AgentFramework, "deep_learning_framework": DeepLearningFramework, - "red_agent_identifier": RedAgentIdentifier, - "action_type": ActionType, - "session_type": SessionType + "red_agent_identifier": RedAgentIdentifier, + "action_type": ActionType, + "session_type": SessionType, + "output_verbose_level": OutputVerboseLevel } for field, enum_class in field_enum_map.items(): @@ -178,28 +198,19 @@ class TrainingConfig: """ data = self.__dict__ if json_serializable: + data["agent_framework"] = self.agent_framework.value + data["deep_learning_framework"] = self.deep_learning_framework.value + data["red_agent_identifier"] = self.red_agent_identifier.value data["action_type"] = self.action_type.value + data["output_verbose_level"] = self.output_verbose_level.value return data -def main_training_config_path() -> Path: - """ - The path to the example training_config_main.yaml file. - - :return: The file path. - """ - path = _EXAMPLE_TRAINING / "training_config_main.yaml" - if not path.exists(): - msg = "Example config not found. Please run 'primaite setup'" - _LOGGER.critical(msg) - raise FileNotFoundError(msg) - - return path - - -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load( + file_path: Union[str, Path], + legacy_file: bool = False +) -> TrainingConfig: """ Read in a training config yaml file. 
@@ -246,7 +257,8 @@ def convert_legacy_training_config_dict( agent_framework: AgentFramework = AgentFramework.SB3, red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO, action_type: ActionType = ActionType.ANY, - num_steps: int = 256 + num_steps: int = 256, + output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -260,13 +272,16 @@ def convert_legacy_training_config_dict( don't have action_type values. :param num_steps: The number of steps to set as legacy training configs don't have num_steps values. + :param output_verbose_level: The agent output verbose level to use as + legacy training configs don't have output_verbose_level values. :return: The converted training config dict. """ config_dict = { "agent_framework": agent_framework.name, "red_agent_identifier": red_agent_identifier.name, "action_type": action_type.name, - "num_steps": num_steps + "num_steps": num_steps, + "output_verbose_level": output_verbose_level } for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e0cfb119..68209713 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -435,7 +435,6 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes - if self.training_config.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: diff --git a/src/primaite/main.py b/src/primaite/main.py index 842b9259..8619dc57 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,305 +1,229 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -""" -The main PrimAITE session runner module. 
- -TODO: This will eventually be refactored out into a proper Session class. -TODO: The passing about of session_dir and timestamp_str is temporary and - will be cleaned up once we move to a proper Session class. -""" -import argparse -import json -import time -from datetime import datetime -from pathlib import Path -from typing import Final, Union -from uuid import uuid4 - -from stable_baselines3 import A2C, PPO -from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.ppo import MlpPolicy as PPOMlp - -from primaite import SESSIONS_DIR, getLogger -from primaite.config.training_config import TrainingConfig -from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file - -_LOGGER = getLogger(__name__) - - -def run_generic(env: Primaite, config_values: TrainingConfig): - """ - Run against a generic agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - """ - for episode in range(0, config_values.num_episodes): - env.reset() - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - action = env.action_space.sample() - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - - # Reset the environment at the end of the episode - - env.close() - - -def run_stable_baselines3_ppo( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -): - """ - Run against a stable_baselines3 PPO agent. 
- - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if config_values.load_agent: - try: - agent = PPO.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) - - env.close() - - -def _write_session_metadata_file( - session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite -): - """ - Write the ``session_metadata.json`` file. - - Creates a ``session_metadata.json`` in the ``session_dir`` directory - and adds the following key/value pairs: - - - uuid: The UUID assigned to the session upon instantiation. - - start_datetime: The date & time the session started in iso format. - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. 
- - env: - - training_config: - - All training config items - - lay_down_config: - - All lay down config items - - """ - metadata_dict = { - "uuid": uuid, - "start_datetime": session_timestamp.isoformat(), - "end_datetime": None, - "total_episodes": None, - "total_time_steps": None, - "env": { - "training_config": env.training_config.to_dict(json_serializable=True), - "lay_down_config": env.lay_down_config, - }, - } - filepath = session_dir / "session_metadata.json" - _LOGGER.debug(f"Writing Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - - -def _update_session_metadata_file(session_dir: Path, env: Primaite): - """ - Update the ``session_metadata.json`` file. - - Updates the `session_metadata.json`` in the ``session_dir`` directory - with the following key/value pairs: - - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - """ - with open(session_dir / "session_metadata.json", "r") as file: - metadata_dict = json.load(file) - - metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = env.episode_count - metadata_dict["total_time_steps"] = env.total_step_count - - filepath = session_dir / "session_metadata.json" - _LOGGER.debug(f"Updating Session Metadata file: {filepath}") - with open(filepath, "w") as file: - json.dump(metadata_dict, file) - - -def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): - """ - Persist an agent. - - Only works for stable baselines3 agents at present. - - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if not isinstance(agent, OnPolicyAlgorithm): - msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." 
- _LOGGER.error(msg) - else: - filepath = session_path / f"agent_saved_{timestamp_str}" - agent.save(filepath) - _LOGGER.debug(f"Trained agent saved as: {filepath}") - - -def _get_session_path(session_timestamp: datetime) -> Path: - """ - Get the directory path the session will output to. - - This is set in the format of: - ~/primaite/sessions//_. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. - """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_dir - session_path.mkdir(exist_ok=True, parents=True) - - return session_path - - -def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): - """Run the PrimAITE Session. - - :param training_config_path: The training config filepath. - :param lay_down_config_path: The lay down config filepath. - """ - # Welcome message - print("Welcome to the Primary-level AI Training Environment (PrimAITE)") - uuid = str(uuid4()) - session_timestamp: Final[datetime] = datetime.now() - session_dir = _get_session_path(session_timestamp) - timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - - print(f"The output directory for this session is: {session_dir}") - - # Create a list of transactions - # A transaction is an object holding the: - # - episode # - # - step # - # - initial observation space - # - action - # - reward - # - new observation space - transaction_list = [] - - # Create the Primaite environment - env = Primaite( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path, - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Writing Session Metadata file...") - - _write_session_metadata_file( - session_dir=session_dir, uuid=uuid, session_timestamp=session_timestamp, env=env - ) - - config_values = 
env.training_config - - # Get the number of steps (which is stored in the child config file) - config_values.num_steps = env.episode_steps - - # Run environment against an agent - if config_values.agent_identifier == "GENERIC": - run_generic(env=env, config_values=config_values) - elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": - run_stable_baselines3_ppo( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": - run_stable_baselines3_a2c( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Session finished") - _LOGGER.debug("Session finished") - - print("Saving transaction logs...") - write_transaction_to_file( - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Updating Session Metadata file...") - _update_session_metadata_file(session_dir=session_dir, env=env) - - print("Finished") - _LOGGER.debug("Finished") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--tc") - parser.add_argument("--ldc") - args = parser.parse_args() - if not args.tc: - _LOGGER.error( - "Please provide a training config file using the --tc " "argument" - ) - if not args.ldc: - _LOGGER.error( - "Please provide a lay down config file using the --ldc " "argument" - ) - run(training_config_path=args.tc, lay_down_config_path=args.ldc) - - +# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +# """ +# The main PrimAITE session runner module. +# +# TODO: This will eventually be refactored out into a proper Session class. +# TODO: The passing about of session_path and timestamp_str is temporary and +# will be cleaned up once we move to a proper Session class. 
+# """ +# import argparse +# import json +# import time +# from datetime import datetime +# from pathlib import Path +# from typing import Final, Union +# from uuid import uuid4 +# +# from stable_baselines3 import A2C, PPO +# from stable_baselines3.common.evaluation import evaluate_policy +# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +# from stable_baselines3.ppo import MlpPolicy as PPOMlp +# +# from primaite import SESSIONS_DIR, getLogger +# from primaite.config.training_config import TrainingConfig +# from primaite.environment.primaite_env import Primaite +# from primaite.transactions.transactions_to_file import \ +# write_transaction_to_file +# +# _LOGGER = getLogger(__name__) +# +# +# def run_generic(env: Primaite, config_values: TrainingConfig): +# """ +# Run against a generic agent. +# +# :param env: An instance of +# :class:`~primaite.environment.primaite_env.Primaite`. +# :param config_values: An instance of +# :class:`~primaite.config.training_config.TrainingConfig`. +# """ +# for episode in range(0, config_values.num_episodes): +# env.reset() +# for step in range(0, config_values.num_steps): +# # Send the observation space to the agent to get an action +# # TEMP - random action for now +# # action = env.blue_agent_action(obs) +# action = env.action_space.sample() +# +# # Run the simulation step on the live environment +# obs, reward, done, info = env.step(action) +# +# # Break if done is True +# if done: +# break +# +# # Introduce a delay between steps +# time.sleep(config_values.time_delay / 1000) +# +# # Reset the environment at the end of the episode +# +# env.close() +# +# +# def run_stable_baselines3_ppo( +# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str +# ): +# """ +# Run against a stable_baselines3 PPO agent. +# +# :param env: An instance of +# :class:`~primaite.environment.primaite_env.Primaite`. 
+# :param config_values: An instance of +# :class:`~primaite.config.training_config.TrainingConfig`. +# :param session_path: The directory path the session is writing to. +# :param timestamp_str: The session timestamp in the format: +# _. +# """ +# if config_values.load_agent: +# try: +# agent = PPO.load( +# config_values.agent_load_file, +# env, +# verbose=0, +# n_steps=config_values.num_steps, +# ) +# except Exception: +# print( +# "ERROR: Could not load agent at location: " +# + config_values.agent_load_file +# ) +# _LOGGER.error("Could not load agent") +# _LOGGER.error("Exception occured", exc_info=True) +# else: +# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) +# +# if config_values.session_type == "TRAINING": +# # We're in a training session +# print("Starting training session...") +# _LOGGER.debug("Starting training session...") +# for episode in range(config_values.num_episodes): +# agent.learn(total_timesteps=config_values.num_steps) +# _save_agent(agent, session_path, timestamp_str) +# else: +# # Default to being in an evaluation session +# print("Starting evaluation session...") +# _LOGGER.debug("Starting evaluation session...") +# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) +# +# env.close() +# +# +# +# +# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): +# """ +# Persist an agent. +# +# Only works for stable baselines3 agents at present. +# +# :param session_path: The directory path the session is writing to. +# :param timestamp_str: The session timestamp in the format: +# _. +# """ +# if not isinstance(agent, OnPolicyAlgorithm): +# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." 
+# _LOGGER.error(msg) +# else: +# filepath = session_path / f"agent_saved_{timestamp_str}" +# agent.save(filepath) +# _LOGGER.debug(f"Trained agent saved as: {filepath}") +# +# +# +# +# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): +# """Run the PrimAITE Session. +# +# :param training_config_path: The training config filepath. +# :param lay_down_config_path: The lay down config filepath. +# """ +# # Welcome message +# print("Welcome to the Primary-level AI Training Environment (PrimAITE)") +# uuid = str(uuid4()) +# session_timestamp: Final[datetime] = datetime.now() +# session_path = _get_session_path(session_timestamp) +# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") +# +# print(f"The output directory for this session is: {session_path}") +# +# # Create a list of transactions +# # A transaction is an object holding the: +# # - episode # +# # - step # +# # - initial observation space +# # - action +# # - reward +# # - new observation space +# transaction_list = [] +# +# # Create the Primaite environment +# env = Primaite( +# training_config_path=training_config_path, +# lay_down_config_path=lay_down_config_path, +# transaction_list=transaction_list, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# +# print("Writing Session Metadata file...") +# +# _write_session_metadata_file( +# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env +# ) +# +# config_values = env.training_config +# +# # Get the number of steps (which is stored in the child config file) +# config_values.num_steps = env.episode_steps +# +# # Run environment against an agent +# if config_values.agent_identifier == "GENERIC": +# run_generic(env=env, config_values=config_values) +# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": +# run_stable_baselines3_ppo( +# env=env, +# config_values=config_values, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# elif 
config_values.agent_identifier == "STABLE_BASELINES3_A2C": +# run_stable_baselines3_a2c( +# env=env, +# config_values=config_values, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# +# print("Session finished") +# _LOGGER.debug("Session finished") +# +# print("Saving transaction logs...") +# write_transaction_to_file( +# transaction_list=transaction_list, +# session_path=session_path, +# timestamp_str=timestamp_str, +# ) +# +# print("Updating Session Metadata file...") +# _update_session_metadata_file(session_path=session_path, env=env) +# +# print("Finished") +# _LOGGER.debug("Finished") +# +# +# if __name__ == "__main__": +# parser = argparse.ArgumentParser() +# parser.add_argument("--tc") +# parser.add_argument("--ldc") +# args = parser.parse_args() +# if not args.tc: +# _LOGGER.error( +# "Please provide a training config file using the --tc " "argument" +# ) +# if not args.ldc: +# _LOGGER.error( +# "Please provide a lay down config file using the --ldc " "argument" +# ) +# run(training_config_path=args.tc, lay_down_config_path=args.ldc) +# +# diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 0efc0acf..8f3380c8 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -3,12 +3,16 @@ from __future__ import annotations import json from datetime import datetime from pathlib import Path -from typing import Final, Optional, Union +from typing import Final, Optional, Union, Dict from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR +from primaite.agents.agent import AgentSessionABC +from primaite.agents.rllib import RLlibPPO +from primaite.agents.sb3 import SB3PPO from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ ActionType +from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite @@ -26,8 +30,8 @@ def 
_get_session_path(session_timestamp: datetime) -> Path: :return: The session directory path. """ date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_dir + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = SESSIONS_DIR / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) _LOGGER.debug(f"Created PrimAITE Session path: {session_path}") @@ -45,211 +49,100 @@ class PrimaiteSession: if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path + self._training_config: Final[TrainingConfig] = training_config.load( + self._training_config_path + ) if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path - - self._auto: Final[bool] = auto - - self._uuid: str = str(uuid4()) - self._session_timestamp: Final[datetime] = datetime.now() - self._session_path: Final[Path] = _get_session_path( - self._session_timestamp + self._lay_down_config: Dict = lay_down_config.load( + self._lay_down_config_path ) - self._timestamp_str: Final[str] = self._session_timestamp.strftime( - "%Y-%m-%d_%H-%M-%S") - self._metadata_path = self._session_path / "session_metadata.json" - - self._env = None - self._training_config: TrainingConfig - self._can_learn: bool = False - _LOGGER.debug("") + self._auto: bool = auto + self._agent_session: AgentSessionABC = None # noqa if self._auto: self.setup() self.learn() - @property - def uuid(self): - """The session UUID.""" - return self._uuid - - def _setup_primaite_env(self, transaction_list: Optional[list] = None): - if not transaction_list: - transaction_list = [] - self._env: Primaite = Primaite( - training_config_path=self._training_config_path, - 
lay_down_config_path=self._lay_down_config_path, - transaction_list=transaction_list, - session_path=self._session_path, - timestamp_str=self._timestamp_str - ) - self._training_config: TrainingConfig = self._env.training_config - - def _write_session_metadata_file(self): - """ - Write the ``session_metadata.json`` file. - - Creates a ``session_metadata.json`` in the ``session_dir`` directory - and adds the following key/value pairs: - - - uuid: The UUID assigned to the session upon instantiation. - - start_datetime: The date & time the session started in iso format. - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. - - env: - - training_config: - - All training config items - - lay_down_config: - - All lay down config items - """ - metadata_dict = { - "uuid": self._uuid, - "start_datetime": self._session_timestamp.isoformat(), - "end_datetime": None, - "total_episodes": None, - "total_time_steps": None, - "env": { - "training_config": self._env.training_config.to_dict( - json_serializable=True - ), - "lay_down_config": self._env.lay_down_config, - }, - } - _LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}") - with open(self._metadata_path, "w") as file: - json.dump(metadata_dict, file) - - def _update_session_metadata_file(self): - """ - Update the ``session_metadata.json`` file. - - Updates the `session_metadata.json`` in the ``session_dir`` directory - with the following key/value pairs: - - - end_datetime: NULL. - - total_episodes: NULL. - - total_time_steps: NULL. 
- """ - with open(self._metadata_path, "r") as file: - metadata_dict = json.load(file) - - metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._env.episode_count - metadata_dict["total_time_steps"] = self._env.total_step_count - - _LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}") - with open(self._metadata_path, "w") as file: - json.dump(metadata_dict, file) - def setup(self): - self._setup_primaite_env() - self._can_learn = True - pass + if self._training_config.agent_framework == AgentFramework.NONE: + if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: + # Stochastic Random Agent + raise NotImplementedError + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: + if self._training_config.action_type == ActionType.NODE: + # Deterministic Hardcoded Agent with Node Action Space + raise NotImplementedError + + elif self._training_config.action_type == ActionType.ACL: + # Deterministic Hardcoded Agent with ACL Action Space + raise NotImplementedError + + elif self._training_config.action_type == ActionType.ANY: + # Deterministic Hardcoded Agent with ANY Action Space + raise NotImplementedError + + else: + # Invalid RedAgentIdentifier ActionType combo + pass + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif self._training_config.agent_framework == AgentFramework.SB3: + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + # Stable Baselines3/Proximal Policy Optimization + self._agent_session = SB3PPO( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Stable Baselines3/Advantage Actor Critic + raise NotImplementedError + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + + elif self._training_config.agent_framework == AgentFramework.RLLIB: + if self._training_config.red_agent_identifier == 
RedAgentIdentifier.PPO: + # Ray RLlib/Proximal Policy Optimization + self._agent_session = RLlibPPO( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + # Ray RLlib/Advantage Actor Critic + raise NotImplementedError + + else: + # Invalid AgentFramework RedAgentIdentifier combo + pass + else: + # Invalid AgentFramework + pass def learn( self, - time_steps: Optional[int], - episodes: Optional[int], - iterations: Optional[int], + time_steps: Optional[int] = None, + episodes: Optional[int] = None, **kwargs ): - if self._can_learn: - # Run environment against an agent - if self._training_config.agent_framework == AgentFramework.NONE: - if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: - # Stochastic Random Agent - run_generic(env=env, config_values=config_values) - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: - if self._training_config.action_type == ActionType.NODE: - # Deterministic Hardcoded Agent with Node Action Space - pass - - elif self._training_config.action_type == ActionType.ACL: - # Deterministic Hardcoded Agent with ACL Action Space - pass - - elif self._training_config.action_type == ActionType.ANY: - # Deterministic Hardcoded Agent with ANY Action Space - pass - - else: - # Invalid RedAgentIdentifier ActionType combo - pass - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass - - elif self._training_config.agent_framework == AgentFramework.SB3: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Stable Baselines3/Proximal Policy Optimization - run_stable_baselines3_ppo( - env=env, - config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Stable Baselines3/Advantage Actor Critic - run_stable_baselines3_a2c( - env=env, - 
config_values=config_values, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass - - elif self._training_config.agent_framework == AgentFramework.RLLIB: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Ray RLlib/Proximal Policy Optimization - pass - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Ray RLlib/Advantage Actor Critic - pass - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass - else: - # Invalid AgentFramework - pass - - print("Session finished") - _LOGGER.debug("Session finished") - - print("Saving transaction logs...") - write_transaction_to_file( - transaction_list=transaction_list, - session_path=session_dir, - timestamp_str=timestamp_str, - ) - - print("Updating Session Metadata file...") - _update_session_metadata_file(session_dir=session_dir, env=env) - - print("Finished") - _LOGGER.debug("Finished") + self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( self, - time_steps: Optional[int], - episodes: Optional[int], + time_steps: Optional[int] = None, + episodes: Optional[int] = None, **kwargs ): - pass - - def export(self): - pass + self._agent_session.evaluate(time_steps, episodes, **kwargs) @classmethod def import_agent( diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 11e68af8..24581597 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -108,5 +108,6 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st csv_writer.writerow(csv_data) csv_file.close() + _LOGGER.debug("Finished writing transactions") except Exception: _LOGGER.error("Could not save the transaction file", exc_info=True) diff --git a/tests/conftest.py b/tests/conftest.py index f1411ba9..1bad5db0 100644 --- a/tests/conftest.py +++ 
b/tests/conftest.py @@ -19,8 +19,8 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path: :return: The session directory path. """ date_dir = session_timestamp.strftime("%Y-%m-%d") - session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_dir + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path From 03ae4884e00019daa10e107c08954bd0bd6519cc Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 19 Jun 2023 21:53:25 +0100 Subject: [PATCH 05/43] #917 - Almost there. All output files being writen for SB3/RLLIB PPO & A2C. Just need to bring in the hardcoded agents then update the testa and docs. --- src/primaite/VERSION | 2 +- src/primaite/agents/agent.py | 18 +- src/primaite/agents/rllib.py | 97 ++++- src/primaite/agents/sb3.py | 34 +- src/primaite/common/enums.py | 4 +- .../lay_down_config_1_DDOS_basic.yaml | 4 - .../lay_down_config_2_DDOS_basic.yaml | 4 - .../lay_down_config_3_DOS_very_basic.yaml | 4 - .../training/training_config_main.yaml | 2 +- src/primaite/config/training_config.py | 2 +- src/primaite/environment/primaite_env.py | 17 +- src/primaite/main.py | 391 ++++++++---------- src/primaite/primaite_session.py | 62 +-- .../transactions/transactions_to_file.py | 8 +- 14 files changed, 321 insertions(+), 328 deletions(-) diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 4111d137..0da493b5 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0rc1 +2.0.0b1 \ No newline at end of file diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 58158dcb..34ad0adb 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -5,19 +5,18 @@ from pathlib import Path from typing import Optional, Final, Dict, Union, List from uuid import 
uuid4 -from primaite import getLogger +from primaite import getLogger, SESSIONS_DIR from primaite.common.enums import OutputVerboseLevel from primaite.config import lay_down_config from primaite.config import training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file + _LOGGER = getLogger(__name__) -def _get_temp_session_path(session_timestamp: datetime) -> Path: +def _get_session_path(session_timestamp: datetime) -> Path: """ Get a temp directory session path the test session will output to. @@ -26,7 +25,7 @@ def _get_temp_session_path(session_timestamp: datetime) -> Path: """ date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path("./") / date_dir / session_path + session_path = SESSIONS_DIR / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path @@ -57,16 +56,16 @@ class AgentSessionABC(ABC): self._env: Primaite self._agent = None - self._transaction_list: List[Dict] = [] self._can_learn: bool = False self._can_evaluate: bool = False self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() "The session timestamp" - self.session_path = _get_temp_session_path(self.session_timestamp) + self.session_path = _get_session_path(self.session_timestamp) "The Session path" self.checkpoints_path = self.session_path / "checkpoints" + self.checkpoints_path.mkdir(parents=True, exist_ok=True) "The Session checkpoints path" self.timestamp_str = self.session_timestamp.strftime( @@ -167,11 +166,6 @@ class AgentSessionABC(ABC): ): if self._can_learn: _LOGGER.debug("Writing transactions") - write_transaction_to_file( - transaction_list=self._transaction_list, - session_path=self.session_path, - timestamp_str=self.timestamp_str, - ) self._update_session_metadata_file() self._can_evaluate = True 
diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 80318499..67ba6213 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,18 +1,23 @@ -import glob -import time -from enum import Enum +import json +from datetime import datetime from pathlib import Path -from typing import Union, Optional +from pathlib import Path +from typing import Optional from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.ppo import PPOConfig +from ray.rllib.algorithms.a2c import A2CConfig +from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env +from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.config import training_config +from primaite.common.enums import AgentFramework, RedAgentIdentifier from primaite.environment.primaite_env import Primaite +_LOGGER = getLogger(__name__) + def _env_creator(env_config): return Primaite( training_config_path=env_config["training_config_path"], @@ -23,7 +28,17 @@ def _env_creator(env_config): ) -class RLlibPPO(AgentSessionABC): +def _custom_log_creator(session_path: Path): + logdir = session_path / "ray_results" + logdir.mkdir(parents=True, exist_ok=True) + + def logger_creator(config): + return UnifiedLogger(config, logdir, loggers=None) + + return logger_creator + + +class RLlibAgent(AgentSessionABC): def __init__( self, @@ -31,17 +46,63 @@ class RLlibPPO(AgentSessionABC): lay_down_config_path ): super().__init__(training_config_path, lay_down_config_path) - self._ppo_config: PPOConfig + if not self._training_config.agent_framework == AgentFramework.RLLIB: + msg = (f"Expected RLLIB agent_framework, " + f"got {self._training_config.agent_framework}") + _LOGGER.error(msg) + raise ValueError(msg) + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + self._agent_config_class = PPOConfig + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + self._agent_config_class 
= A2CConfig + else: + msg = ("Expected PPO or A2C red_agent_identifier, " + f"got {self._training_config.red_agent_identifier.value}") + _LOGGER.error(msg) + raise ValueError(msg) + self._agent_config: PPOConfig + self._current_result: dict self._setup() - self._agent.save() + _LOGGER.debug( + f"Created {self.__class__.__name__} using: " + f"agent_framework={self._training_config.agent_framework}, " + f"red_agent_identifier=" + f"{self._training_config.red_agent_identifier}, " + f"deep_learning_framework=" + f"{self._training_config.deep_learning_framework}" + ) + + def _update_session_metadata_file(self): + """ + Update the ``session_metadata.json`` file. + + Updates the `session_metadata.json`` in the ``session_path`` directory + with the following key/value pairs: + + - end_datetime: The date & time the session ended in iso format. + - total_episodes: The total number of training episodes completed. + - total_time_steps: The total number of training time steps completed. + """ + with open(self.session_path / "session_metadata.json", "r") as file: + metadata_dict = json.load(file) + + metadata_dict["end_datetime"] = datetime.now().isoformat() + metadata_dict["total_episodes"] = self._current_result["episodes_total"] + metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] + + filepath = self.session_path / "session_metadata.json" + _LOGGER.debug(f"Updating Session Metadata file: {filepath}") + with open(filepath, "w") as file: + json.dump(metadata_dict, file) + _LOGGER.debug("Finished updating session metadata file") def _setup(self): super()._setup() register_env("primaite", _env_creator) - self._ppo_config = PPOConfig() + self._agent_config = self._agent_config_class() - self._ppo_config.environment( + self._agent_config.environment( env="primaite", env_config=dict( training_config_path=self._training_config_path, @@ -52,19 +113,21 @@ class RLlibPPO(AgentSessionABC): ) ) - self._ppo_config.training( + self._agent_config.training( 
train_batch_size=self._training_config.num_steps ) - self._ppo_config.framework( - framework=self._training_config.deep_learning_framework.value + self._agent_config.framework( + framework=self._training_config.deep_learning_framework ) - self._ppo_config.rollouts( + self._agent_config.rollouts( num_rollout_workers=1, num_envs_per_worker=1, horizon=self._training_config.num_steps ) - self._agent: Algorithm = self._ppo_config.build() + self._agent: Algorithm = self._agent_config.build( + logger_creator=_custom_log_creator(self.session_path) + ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes @@ -84,8 +147,8 @@ class RLlibPPO(AgentSessionABC): ): # Temporarily override train_batch_size and horizon if time_steps: - self._ppo_config.train_batch_size = time_steps - self._ppo_config.horizon = time_steps + self._agent_config.train_batch_size = time_steps + self._agent_config.horizon = time_steps if not episodes: episodes = self._training_config.num_episodes diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 6e6d8a5d..3cd2e50a 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,23 +1,48 @@ from typing import Optional import numpy as np -from stable_baselines3 import PPO +from stable_baselines3 import PPO, A2C +from primaite import getLogger from primaite.agents.agent import AgentSessionABC +from primaite.common.enums import RedAgentIdentifier, AgentFramework from primaite.environment.primaite_env import Primaite from stable_baselines3.ppo import MlpPolicy as PPOMlp +_LOGGER = getLogger(__name__) -class SB3PPO(AgentSessionABC): + +class SB3Agent(AgentSessionABC): def __init__( self, training_config_path, lay_down_config_path ): super().__init__(training_config_path, lay_down_config_path) + if not self._training_config.agent_framework == AgentFramework.SB3: + msg = (f"Expected SB3 agent_framework, " + f"got {self._training_config.agent_framework}") + _LOGGER.error(msg) + raise 
ValueError(msg) + if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + self._agent_class = PPO + elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + self._agent_class = A2C + else: + msg = ("Expected PPO or A2C red_agent_identifier, " + f"got {self._training_config.red_agent_identifier.value}") + _LOGGER.error(msg) + raise ValueError(msg) + self._tensorboard_log_path = self.session_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) self._setup() + _LOGGER.debug( + f"Created {self.__class__.__name__} using: " + f"agent_framework={self._training_config.agent_framework}, " + f"red_agent_identifier=" + f"{self._training_config.red_agent_identifier}" + ) def _setup(self): super()._setup() @@ -28,10 +53,10 @@ class SB3PPO(AgentSessionABC): session_path=self.session_path, timestamp_str=self.timestamp_str ) - self._agent = PPO( + self._agent = self._agent_class( PPOMlp, self._env, - verbose=1, + verbose=self._training_config.output_verbose_level, n_steps=self._training_config.num_steps, tensorboard_log=self._tensorboard_log_path ) @@ -65,6 +90,7 @@ class SB3PPO(AgentSessionABC): for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() + self._env.close() super().learn() def evaluate( diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 0c787e87..89bfd737 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -80,13 +80,13 @@ class Protocol(Enum): class SessionType(Enum): - "The type of PrimAITE Session to be run." 
+ """The type of PrimAITE Session to be run.""" TRAINING = 1 EVALUATION = 2 BOTH = 3 -class VerboseLevel(Enum): +class VerboseLevel(IntEnum): """PrimAITE Session Output verbose level.""" NO_OUTPUT = 0 INFO = 1 diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml index f7c1e372..3f0c546a 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_1_DDOS_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 128 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml index e4a3385d..39bf7dac 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_2_DDOS_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 128 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml b/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml index 9f37a6f0..619a0d35 100644 --- a/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml +++ b/src/primaite/config/_package_data/lay_down/lay_down_config_3_DOS_very_basic.yaml @@ -1,7 +1,3 @@ -- item_type: ACTIONS - type: NODE -- item_type: STEPS - steps: 256 - item_type: PORTS ports_list: - port: '80' diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 703f37f5..d7b4db98 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ 
b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -59,7 +59,7 @@ observation_space_high_value: 1000000000 # "NONE" (No Output) # "INFO" (Info Messages) # "ALL" (All Messages) -output_verbose_level: INFO +output_verbose_level: NONE # Reward values # Generic diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 0d39f9c4..4695f2f5 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -185,7 +185,6 @@ class TrainingConfig: for field, enum_class in field_enum_map.items(): if field in config_dict: config_dict[field] = enum_class[config_dict[field]] - return TrainingConfig(**config_dict) def to_dict(self, json_serializable: bool = True): @@ -203,6 +202,7 @@ class TrainingConfig: data["red_agent_identifier"] = self.red_agent_identifier.value data["action_type"] = self.action_type.value data["output_verbose_level"] = self.output_verbose_level.value + data["session_type"] = self.session_type.value return data diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 68209713..0876f070 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -45,6 +45,8 @@ from primaite.pol.ier import IER from primaite.pol.red_agent_pol import apply_red_agent_iers, \ apply_red_agent_node_pol from primaite.transactions.transaction import Transaction +from primaite.transactions.transactions_to_file import \ + write_transaction_to_file _LOGGER = logging.getLogger(__name__) _LOGGER.setLevel(logging.INFO) @@ -407,10 +409,19 @@ class Primaite(Env): # Return return self.env_obs, reward, done, self.step_info - def __close__(self): - """Override close function.""" - self.csv_file.close() + def close(self): + self.__close__() + def __close__(self): + """ + Override close function + """ + write_transaction_to_file( + self.transaction_list, + self.session_path, + self.timestamp_str + ) + 
self.csv_file.close() def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() diff --git a/src/primaite/main.py b/src/primaite/main.py index 8619dc57..34134ba2 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,229 +1,162 @@ -# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -# """ -# The main PrimAITE session runner module. -# -# TODO: This will eventually be refactored out into a proper Session class. -# TODO: The passing about of session_path and timestamp_str is temporary and -# will be cleaned up once we move to a proper Session class. -# """ -# import argparse -# import json -# import time -# from datetime import datetime -# from pathlib import Path -# from typing import Final, Union -# from uuid import uuid4 -# -# from stable_baselines3 import A2C, PPO -# from stable_baselines3.common.evaluation import evaluate_policy -# from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -# from stable_baselines3.ppo import MlpPolicy as PPOMlp -# -# from primaite import SESSIONS_DIR, getLogger -# from primaite.config.training_config import TrainingConfig -# from primaite.environment.primaite_env import Primaite -# from primaite.transactions.transactions_to_file import \ -# write_transaction_to_file -# -# _LOGGER = getLogger(__name__) -# -# -# def run_generic(env: Primaite, config_values: TrainingConfig): -# """ -# Run against a generic agent. -# -# :param env: An instance of -# :class:`~primaite.environment.primaite_env.Primaite`. -# :param config_values: An instance of -# :class:`~primaite.config.training_config.TrainingConfig`. 
-# """ -# for episode in range(0, config_values.num_episodes): -# env.reset() -# for step in range(0, config_values.num_steps): -# # Send the observation space to the agent to get an action -# # TEMP - random action for now -# # action = env.blue_agent_action(obs) -# action = env.action_space.sample() -# -# # Run the simulation step on the live environment -# obs, reward, done, info = env.step(action) -# -# # Break if done is True -# if done: -# break -# -# # Introduce a delay between steps -# time.sleep(config_values.time_delay / 1000) -# -# # Reset the environment at the end of the episode -# -# env.close() -# -# -# def run_stable_baselines3_ppo( -# env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -# ): -# """ -# Run against a stable_baselines3 PPO agent. -# -# :param env: An instance of -# :class:`~primaite.environment.primaite_env.Primaite`. -# :param config_values: An instance of -# :class:`~primaite.config.training_config.TrainingConfig`. -# :param session_path: The directory path the session is writing to. -# :param timestamp_str: The session timestamp in the format: -# _. 
-# """ -# if config_values.load_agent: -# try: -# agent = PPO.load( -# config_values.agent_load_file, -# env, -# verbose=0, -# n_steps=config_values.num_steps, -# ) -# except Exception: -# print( -# "ERROR: Could not load agent at location: " -# + config_values.agent_load_file -# ) -# _LOGGER.error("Could not load agent") -# _LOGGER.error("Exception occured", exc_info=True) -# else: -# agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) -# -# if config_values.session_type == "TRAINING": -# # We're in a training session -# print("Starting training session...") -# _LOGGER.debug("Starting training session...") -# for episode in range(config_values.num_episodes): -# agent.learn(total_timesteps=config_values.num_steps) -# _save_agent(agent, session_path, timestamp_str) -# else: -# # Default to being in an evaluation session -# print("Starting evaluation session...") -# _LOGGER.debug("Starting evaluation session...") -# evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) -# -# env.close() -# -# -# -# -# def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): -# """ -# Persist an agent. -# -# Only works for stable baselines3 agents at present. -# -# :param session_path: The directory path the session is writing to. -# :param timestamp_str: The session timestamp in the format: -# _. -# """ -# if not isinstance(agent, OnPolicyAlgorithm): -# msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." -# _LOGGER.error(msg) -# else: -# filepath = session_path / f"agent_saved_{timestamp_str}" -# agent.save(filepath) -# _LOGGER.debug(f"Trained agent saved as: {filepath}") -# -# -# -# -# def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): -# """Run the PrimAITE Session. -# -# :param training_config_path: The training config filepath. -# :param lay_down_config_path: The lay down config filepath. 
-# """ -# # Welcome message -# print("Welcome to the Primary-level AI Training Environment (PrimAITE)") -# uuid = str(uuid4()) -# session_timestamp: Final[datetime] = datetime.now() -# session_path = _get_session_path(session_timestamp) -# timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") -# -# print(f"The output directory for this session is: {session_path}") -# -# # Create a list of transactions -# # A transaction is an object holding the: -# # - episode # -# # - step # -# # - initial observation space -# # - action -# # - reward -# # - new observation space -# transaction_list = [] -# -# # Create the Primaite environment -# env = Primaite( -# training_config_path=training_config_path, -# lay_down_config_path=lay_down_config_path, -# transaction_list=transaction_list, -# session_path=session_path, -# timestamp_str=timestamp_str, -# ) -# -# print("Writing Session Metadata file...") -# -# _write_session_metadata_file( -# session_path=session_path, uuid=uuid, session_timestamp=session_timestamp, env=env -# ) -# -# config_values = env.training_config -# -# # Get the number of steps (which is stored in the child config file) -# config_values.num_steps = env.episode_steps -# -# # Run environment against an agent -# if config_values.agent_identifier == "GENERIC": -# run_generic(env=env, config_values=config_values) -# elif config_values.agent_identifier == "STABLE_BASELINES3_PPO": -# run_stable_baselines3_ppo( -# env=env, -# config_values=config_values, -# session_path=session_path, -# timestamp_str=timestamp_str, -# ) -# elif config_values.agent_identifier == "STABLE_BASELINES3_A2C": -# run_stable_baselines3_a2c( -# env=env, -# config_values=config_values, -# session_path=session_path, -# timestamp_str=timestamp_str, -# ) -# -# print("Session finished") -# _LOGGER.debug("Session finished") -# -# print("Saving transaction logs...") -# write_transaction_to_file( -# transaction_list=transaction_list, -# session_path=session_path, -# 
timestamp_str=timestamp_str, -# ) -# -# print("Updating Session Metadata file...") -# _update_session_metadata_file(session_path=session_path, env=env) -# -# print("Finished") -# _LOGGER.debug("Finished") -# -# -# if __name__ == "__main__": -# parser = argparse.ArgumentParser() -# parser.add_argument("--tc") -# parser.add_argument("--ldc") -# args = parser.parse_args() -# if not args.tc: -# _LOGGER.error( -# "Please provide a training config file using the --tc " "argument" -# ) -# if not args.ldc: -# _LOGGER.error( -# "Please provide a lay down config file using the --ldc " "argument" -# ) -# run(training_config_path=args.tc, lay_down_config_path=args.ldc) -# -# +# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +""" +The main PrimAITE session runner module. + +TODO: This will eventually be refactored out into a proper Session class. +TODO: The passing about of session_path and timestamp_str is temporary and + will be cleaned up once we move to a proper Session class. +""" +import argparse +import json +import time +from datetime import datetime +from pathlib import Path +from typing import Final, Union +from uuid import uuid4 + +from stable_baselines3 import A2C, PPO +from stable_baselines3.common.evaluation import evaluate_policy +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.ppo import MlpPolicy as PPOMlp + +from primaite import SESSIONS_DIR, getLogger +from primaite.config.training_config import TrainingConfig +from primaite.environment.primaite_env import Primaite +from primaite.primaite_session import PrimaiteSession +from primaite.transactions.transactions_to_file import \ + write_transaction_to_file + +_LOGGER = getLogger(__name__) + + +def run_generic(env: Primaite, config_values: TrainingConfig): + """ + Run against a generic agent. + + :param env: An instance of + :class:`~primaite.environment.primaite_env.Primaite`. 
+ :param config_values: An instance of + :class:`~primaite.config.training_config.TrainingConfig`. + """ + for episode in range(0, config_values.num_episodes): + env.reset() + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + action = env.action_space.sample() + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + + env.close() + + +def run_stable_baselines3_ppo( + env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str +): + """ + Run against a stable_baselines3 PPO agent. + + :param env: An instance of + :class:`~primaite.environment.primaite_env.Primaite`. + :param config_values: An instance of + :class:`~primaite.config.training_config.TrainingConfig`. + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. 
+ """ + if config_values.load_agent: + try: + agent = PPO.load( + config_values.agent_load_file, + env, + verbose=0, + n_steps=config_values.num_steps, + ) + except Exception: + print( + "ERROR: Could not load agent at location: " + + config_values.agent_load_file + ) + _LOGGER.error("Could not load agent") + _LOGGER.error("Exception occured", exc_info=True) + else: + agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) + + if config_values.session_type == "TRAINING": + # We're in a training session + print("Starting training session...") + _LOGGER.debug("Starting training session...") + for episode in range(config_values.num_episodes): + agent.learn(total_timesteps=config_values.num_steps) + _save_agent(agent, session_path, timestamp_str) + else: + # Default to being in an evaluation session + print("Starting evaluation session...") + _LOGGER.debug("Starting evaluation session...") + evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) + + env.close() + + + + +def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): + """ + Persist an agent. + + Only works for stable baselines3 agents at present. + + :param session_path: The directory path the session is writing to. + :param timestamp_str: The session timestamp in the format: + _. + """ + if not isinstance(agent, OnPolicyAlgorithm): + msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." + _LOGGER.error(msg) + else: + filepath = session_path / f"agent_saved_{timestamp_str}" + agent.save(filepath) + _LOGGER.debug(f"Trained agent saved as: {filepath}") + + + + +def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): + """Run the PrimAITE Session. + + :param training_config_path: The training config filepath. + :param lay_down_config_path: The lay down config filepath. 
+ """ + session = PrimaiteSession(training_config_path, lay_down_config_path) + + session.setup() + session.learn() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--tc") + parser.add_argument("--ldc") + args = parser.parse_args() + if not args.tc: + _LOGGER.error( + "Please provide a training config file using the --tc " "argument" + ) + if not args.ldc: + _LOGGER.error( + "Please provide a lay down config file using the --ldc " "argument" + ) + run(training_config_path=args.tc, lay_down_config_path=args.ldc) + + diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 8f3380c8..a4148d12 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -8,10 +8,10 @@ from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR from primaite.agents.agent import AgentSessionABC -from primaite.agents.rllib import RLlibPPO -from primaite.agents.sb3 import SB3PPO +from primaite.agents.rllib import RLlibAgent +from primaite.agents.sb3 import SB3Agent from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ - ActionType + ActionType, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite @@ -95,35 +95,19 @@ class PrimaiteSession: pass elif self._training_config.agent_framework == AgentFramework.SB3: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Stable Baselines3/Proximal Policy Optimization - self._agent_session = SB3PPO( - self._training_config_path, - self._lay_down_config_path - ) - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Stable Baselines3/Advantage Actor Critic - raise NotImplementedError - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass + # Stable Baselines3 Agent + self._agent_session = SB3Agent( + self._training_config_path, + 
self._lay_down_config_path + ) elif self._training_config.agent_framework == AgentFramework.RLLIB: - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: - # Ray RLlib/Proximal Policy Optimization - self._agent_session = RLlibPPO( - self._training_config_path, - self._lay_down_config_path - ) + # Ray RLlib Agent + self._agent_session = RLlibAgent( + self._training_config_path, + self._lay_down_config_path + ) - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: - # Ray RLlib/Advantage Actor Critic - raise NotImplementedError - - else: - # Invalid AgentFramework RedAgentIdentifier combo - pass else: # Invalid AgentFramework pass @@ -134,7 +118,8 @@ class PrimaiteSession: episodes: Optional[int] = None, **kwargs ): - self._agent_session.learn(time_steps, episodes, **kwargs) + if not self._training_config.session_type == SessionType.EVALUATION: + self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( self, @@ -142,18 +127,5 @@ class PrimaiteSession: episodes: Optional[int] = None, **kwargs ): - self._agent_session.evaluate(time_steps, episodes, **kwargs) - - @classmethod - def import_agent( - cls, - gent_path: str, - training_config_path: str, - lay_down_config_path: str - ) -> PrimaiteSession: - session = PrimaiteSession(training_config_path, lay_down_config_path) - - # Reset the UUID - session._uuid = "" - - return session + if not self._training_config.session_type == SessionType.TRAINING: + self._agent_session.evaluate(time_steps, episodes, **kwargs) diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 24581597..ed7a8f1c 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -4,6 +4,8 @@ import csv from pathlib import Path +import numpy as np + from primaite import getLogger _LOGGER = getLogger(__name__) @@ -54,8 +56,12 @@ def write_transaction_to_file(transaction_list, 
session_path: Path, timestamp_st # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action # space as "AS_1" # This will be tied into the PrimAITE Use Case so that they make sense + template_transation = transaction_list[0] - action_length = template_transation.action_space.size + if isinstance(template_transation.action_space, int): + action_length = template_transation.action_space + else: + action_length = template_transation.action_space.size obs_shape = template_transation.obs_space_post.shape obs_assets = template_transation.obs_space_post.shape[0] if len(obs_shape) == 1: From a2cc4233b5917f63328a846e9a52a13c7e9be403 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 20 Jun 2023 16:06:55 +0100 Subject: [PATCH 06/43] #917 -Finished integrating all agents to either train (policy agents) or evaluate (hard-coded agents). Still some fixing up to do, tidying up, loading etc. also docs. But this is all now working. --- docs/source/config.rst | 2 +- src/primaite/VERSION | 2 +- src/primaite/agents/agent.py | 73 +++- src/primaite/agents/hardcoded_acl.py | 376 ++++++++++++++++ src/primaite/agents/hardcoded_node.py | 97 +++++ src/primaite/agents/rllib.py | 20 +- src/primaite/agents/sb3.py | 19 +- src/primaite/agents/simple.py | 60 +++ src/primaite/agents/utils.py | 401 +++++++++++++++++- src/primaite/common/enums.py | 28 +- .../training/training_config_main.yaml | 31 +- src/primaite/config/training_config.py | 37 +- src/primaite/environment/primaite_env.py | 2 +- src/primaite/main.py | 128 +----- src/primaite/primaite_session.py | 67 ++- tests/config/legacy/new_training_config.yaml | 2 +- 16 files changed, 1125 insertions(+), 220 deletions(-) create mode 100644 src/primaite/agents/hardcoded_acl.py create mode 100644 src/primaite/agents/hardcoded_node.py create mode 100644 src/primaite/agents/simple.py diff --git a/docs/source/config.rst b/docs/source/config.rst index 1bea0671..52748eec 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ 
-29,7 +29,7 @@ The environment config file consists of the following attributes: * SB3 - Stable Baselines3 * RLLIB - Ray RLlib. -* **red_agent_identifier** +* **agent_identifier** This identifies the agent to use for the session. Select from one of the following: diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 0da493b5..3068ee27 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0b1 \ No newline at end of file +2.0.0rc1 \ No newline at end of file diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 34ad0adb..812072ba 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -1,4 +1,5 @@ import json +import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path @@ -12,7 +13,6 @@ from primaite.config import training_config from primaite.config.training_config import TrainingConfig from primaite.environment.primaite_env import Primaite - _LOGGER = getLogger(__name__) @@ -196,50 +196,77 @@ class AgentSessionABC(ABC): pass -class DeterministicAgentSessionABC(AgentSessionABC): - @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): - self._training_config_path = training_config_path - self._lay_down_config_path = lay_down_config_path - self._env: Primaite - self._agent = None +class HardCodedAgentSessionABC(AgentSessionABC): + def __init__(self, training_config_path, lay_down_config_path): + super().__init__(training_config_path, lay_down_config_path) + self._setup() - @abstractmethod def _setup(self): + self._env: Primaite = Primaite( + training_config_path=self._training_config_path, + lay_down_config_path=self._lay_down_config_path, + transaction_list=[], + session_path=self.session_path, + timestamp_str=self.timestamp_str + ) + super()._setup() + self._can_learn = False + self._can_evaluate = True + + + def _save_checkpoint(self): pass - @abstractmethod def _get_latest_checkpoint(self): pass def 
learn( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): _LOGGER.warning("Deterministic agents cannot learn") @abstractmethod + def _calculate_action(self, obs): + pass + def evaluate( self, time_steps: Optional[int] = None, - episodes: Optional[int] = None + episodes: Optional[int] = None, + **kwargs ): - pass + if not time_steps: + time_steps = self._training_config.num_steps + + if not episodes: + episodes = self._training_config.num_episodes + + for episode in range(episodes): + # Reset env and collect initial observation + obs = self._env.reset() + for step in range(time_steps): + # Calculate action + action = self._calculate_action(obs) + + # Perform the step + obs, reward, done, info = self._env.step(action) + + if done: + break + + # Introduce a delay between steps + time.sleep(self._training_config.time_delay / 1000) + self._env.close() @classmethod - @abstractmethod def load(cls): - pass + _LOGGER.warning("Deterministic agents cannot be loaded") - @abstractmethod def save(self): - pass + _LOGGER.warning("Deterministic agents cannot be saved") - @abstractmethod def export(self): - pass + _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py new file mode 100644 index 00000000..4ad08f6e --- /dev/null +++ b/src/primaite/agents/hardcoded_acl.py @@ -0,0 +1,376 @@ +import numpy as np + +from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.utils import ( + get_new_action, + get_node_of_ip, + transform_action_acl_enum, + transform_change_obs_readable, +) +from primaite.common.enums import HardCodedAgentView + + +class HardCodedACLAgent(HardCodedAgentSessionABC): + + def _calculate_action(self, obs): + if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: + # Basic view action using only the current observation + return 
self._calculate_action_basic_view(obs) + else: + # full view action using observation space, action + # history and reward feedback + return self._calculate_action_full_view(obs) + + def get_blocked_green_iers(self, green_iers, acl, nodes): + blocked_green_iers = {} + + for green_ier_id, green_ier in green_iers.items(): + source_node_id = green_ier.get_source_node_id() + source_node_address = nodes[source_node_id].ip_address + dest_node_id = green_ier.get_dest_node_id() + dest_node_address = nodes[dest_node_id].ip_address + protocol = green_ier.get_protocol() # e.g. 'TCP' + port = green_ier.get_port() + + # Can be blocked by an ACL or by default (no allow rule exists) + if acl.is_blocked(source_node_address, dest_node_address, protocol, + port): + blocked_green_iers[green_ier_id] = green_ier + + return blocked_green_iers + + def get_matching_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get matching ACL rules for an IER. + """ + + source_node_id = ier.get_source_node_id() + source_node_address = nodes[source_node_id].ip_address + dest_node_id = ier.get_dest_node_id() + dest_node_address = nodes[dest_node_id].ip_address + protocol = ier.get_protocol() # e.g. 'TCP' + port = ier.get_port() + + matching_rules = acl.get_relevant_rules(source_node_address, + dest_node_address, protocol, + port) + return matching_rules + + def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get blocking ACL rules for an IER. + Warning: Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked) + """ + + matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) + + blocked_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "DENY": + blocked_rules[rule_key] = rule_value + + return blocked_rules + + def get_allow_acl_rules_for_ier(self, ier, acl, nodes): + """ + Get all allowing ACL rules for an IER. 
+ """ + + matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "ALLOW": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol, + port, acl, + nodes, services_list): + if source_node_id != "ANY": + source_node_address = nodes[str(source_node_id)].ip_address + else: + source_node_address = source_node_id + + if dest_node_id != "ANY": + dest_node_address = nodes[str(dest_node_id)].ip_address + else: + dest_node_address = dest_node_id + + if protocol != "ANY": + protocol = services_list[ + protocol - 1] # -1 as dont have to account for ANY in list of services + + matching_rules = acl.get_relevant_rules(source_node_address, + dest_node_address, protocol, + port) + return matching_rules + + def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol, + port, acl, + nodes, services_list): + matching_rules = self.get_matching_acl_rules(source_node_id, + dest_node_id, + protocol, port, acl, + nodes, + services_list) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "ALLOW": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port, + acl, + nodes, services_list): + matching_rules = self.get_matching_acl_rules(source_node_id, + dest_node_id, + protocol, port, acl, + nodes, + services_list) + + allowed_rules = {} + for rule_key, rule_value in matching_rules.items(): + if rule_value.get_permission() == "DENY": + allowed_rules[rule_key] = rule_value + + return allowed_rules + + def _calculate_action_full_view(self, obs): + """ + Given an observation and the environment calculate a good acl-based action for the blue agent to take + + Knowledge of just the observation space is insufficient for a perfect solution, 
as we need to know: + + - Which ACL rules already exist, - otherwise: + - The agent would perminently get stuck in a loop of performing the same action over and over. + (best action is to block something, but its already blocked but doesn't know this) + - The agent would be unable to interact with existing rules (e.g. how would it know to delete a rule, + if it doesnt know what rules exist) + - The Green IERs (optional) - It often needs to know which traffic it should be allowing. For example + in the default config one of the green IERs is blocked by default, but it has no way of knowing this + based on the observation space. Additionally, potentially in the future, once a node state + has been fixed (no longer compromised), it needs a way to know it should reallow traffic. + A RL agent can learn what the green IERs are on its own - but the rule based agent cannot easily do this. + + There doesn't seem like there's much that can be done if an Operating or OS State is compromised + + If a service node becomes compromised there's a decision to make - do we block that service? + Pros: It cannot launch an attack on another node, so the node will not be able to be OVERWHELMED + Cons: Will block a green IER, decreasing the reward + We decide to block the service. + + Potentially a better solution (for the reward) would be to block the incomming traffic from compromised + nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing + an overwhelmed state, so we don't do this. + + """ + #obs = convert_to_old_obs(obs) + r_obs = transform_change_obs_readable(obs) + _, _, _, *s = r_obs + + if len(r_obs) == 4: # only 1 service + s = [*s] + + # 1. Check if node is compromised. If so we want to block its outwards services + # a. If it is comprimised check if there's an allow rule we should delete. + # cons: might delete a multi-rule from any source node (ANY -> x) + # b. 
OPTIONAL (Deny rules not needed): Check if there already exists an existing Deny Rule so not to duplicate + # c. OPTIONAL (no allow rule = blocked): Add a DENY rule + found_action = False + for service_num, service_states in enumerate(s): + for x, service_state in enumerate(service_states): + if service_state == "COMPROMISED": + + action_source_id = x + 1 # +1 as 0 is any + action_destination_id = "ANY" + action_protocol = service_num + 1 # +1 as 0 is any + action_port = "ANY" + + allow_rules = self.get_allow_acl_rules( + action_source_id, + action_destination_id, + action_protocol, + action_port, + self._env.acl, + self._env.nodes, + self._env.services_list, + ) + deny_rules = self.get_deny_acl_rules( + action_source_id, + action_destination_id, + action_protocol, + action_port, + self._env.acl, + self._env.nodes, + self._env.services_list, + ) + if len(allow_rules) > 0: + # Check if there's an allow rule we should delete + rule = list(allow_rules.values())[0] + action_decision = "DELETE" + action_permission = "ALLOW" + action_source_ip = rule.get_source_ip() + action_source_id = int( + get_node_of_ip(action_source_ip, self._env.nodes)) + action_destination_ip = rule.get_dest_ip() + action_destination_id = int( + get_node_of_ip(action_destination_ip, + self._env.nodes)) + action_protocol_name = rule.get_protocol() + action_protocol = ( + self._env.services_list.index( + action_protocol_name) + 1 + ) # convert name e.g. 'TCP' to index + action_port_name = rule.get_port() + action_port = self._env.ports_list.index( + action_port_name) + 1 # convert port name e.g. 
'80' to index + + found_action = True + break + elif len(deny_rules) > 0: + # TODO OPTIONAL + # If there's already a DENY RULE, that blocks EVERYTHING from the source ip we don't need + # to create another + # Check to see if the DENY rule really blocks everything (ANY) or just a specific rule + continue + else: + # TODO OPTIONAL: Add a DENY rule, optional as by default no allow rule == blocked + action_decision = "CREATE" + action_permission = "DENY" + break + if found_action: + break + + # 2. If NO Node is Comprimised, or the node has already been blocked, check the green IERs and + # add an Allow rule if the green IER is being blocked. + # a. OPTIONAL - NOT IMPLEMENTED (optional as a deny rule does not overwrite an allow rule): + # If there's a DENY rule delete it if: + # - There isn't already a deny rule + # - It doesnt allows a comprimised node to become operational. + # b. Add an ALLOW rule if: + # - There isn't already an allow rule + # - It doesnt allows a comprimised node to become operational + + if not found_action: + # Which Green IERS are blocked + blocked_green_iers = self.get_blocked_green_iers( + self._env.green_iers, self._env.acl, + self._env.nodes) + for ier_key, ier in blocked_green_iers.items(): + + # Which ALLOW rules are allowing this IER (none) + allowing_rules = self.get_allow_acl_rules_for_ier(ier, + self._env.acl, + self._env.nodes) + + # If there are no blocking rules, it may be being blocked by default + # If there is already an allow rule + node_id_to_check = int(ier.get_source_node_id()) + service_name_to_check = ier.get_protocol() + service_id_to_check = self._env.services_list.index( + service_name_to_check) + + # Service state of the the source node in the ier + service_state = s[service_id_to_check][node_id_to_check - 1] + + if len(allowing_rules) == 0 and service_state != "COMPROMISED": + action_decision = "CREATE" + action_permission = "ALLOW" + action_source_id = int(ier.get_source_node_id()) + action_destination_id = 
def _calculate_action_basic_view(self, obs):
    """Choose an ACL action for the blue agent from the current observation only.

    No knowledge of previously taken actions and no reward feedback is used.

    The aim is to block traffic originating from a compromised node. Because
    this view cannot tell:
        1. which ACL rules already exist, or
        2. which actions the agent has already tried,
    the destination and port of the rule to delete are picked at random.

    There is a high probability that the correct rule will not be deleted
    before the state becomes overwhelmed.

    Currently a deny rule does not overwrite an allow rule, so the allow
    rules must be deleted.

    :param obs: The observation, as accepted by ``transform_change_obs_readable``.
    :return: The action key looked up in ``self._env.action_dict``.
    """
    action_lookup = self._env.action_dict
    readable_obs = transform_change_obs_readable(obs)
    # Readable layout: ids, operating states, OS states, then one list per service.
    _, operating_states, _, *service_columns = readable_obs

    # Rows with an operating state of "NONE" are not counted as nodes
    # (the original comment describes them as links).
    node_count = sum(1 for state in operating_states if state != "NONE")

    for service_index, states in enumerate(service_columns):
        compromised_rows = [
            row for row, state in enumerate(states) if state == "COMPROMISED"
        ]
        if not compromised_rows:
            # Nothing compromised for this service, try the next one.
            continue

        # +1 because 0 means ANY in the action encoding.
        source_id = np.random.choice(compromised_rows) + 1

        # Randomly select a destination to stop allowing traffic to.
        destination_pool = list(range(1, node_count + 1)) + ["ANY"]
        destination_id = np.random.choice(destination_pool)
        if destination_id != "ANY":
            destination_id = int(destination_id)

        # Randomly select a port.
        # Bad assumption that the number of protocols equals the number of
        # ports AND that no rules exist with an ANY port.
        port_index = np.random.choice(
            list(range(1, len(service_columns) + 1))
        )

        chosen = [
            "DELETE",
            "ALLOW",
            source_id,
            destination_id,
            service_index + 1,  # +1 because 0 means ANY
            port_index,
        ]
        chosen = transform_action_acl_enum(chosen)
        # Only one action can be performed per step.
        return get_new_action(chosen, action_lookup)

    # Nothing useful found: fall back to the do-nothing action.
    fallback = transform_action_acl_enum(
        ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"]
    )
    return get_new_action(fallback, action_lookup)
class HardCodedNodeAgent(HardCodedAgentSessionABC):
    """Hard-coded blue agent that chooses node-based repair actions."""

    def _calculate_action(self, obs):
        """Given an observation, calculate a good node-based action to take.

        Only one action can be performed per step, so the first matching rule
        wins. Rules are checked in this order:

        1. Patch the OS of a node whose OS state is COMPROMISED.
        2. Patch a service that is COMPROMISED.
        3. Patch a service that is OVERWHELMED.
        4. Turn on a node whose operating state is OFF.
        5. Otherwise return an action that does no harm.

        :param obs: The observation, as accepted by ``transform_change_obs_readable``.
        :return: The action key looked up in ``self._env.action_dict``.
        """
        action_dict = self._env.action_dict
        # Readable layout: ids, operating states, OS states, then one list per service.
        _, operating_states, os_states, *services = transform_change_obs_readable(obs)

        def encode(node_id, node_property, property_action, service_index):
            """Convert a readable node action into its key in ``action_dict``."""
            readable = [node_id, node_property, property_action, service_index]
            return get_new_action(transform_action_node_enum(readable), action_dict)

        # First see if any OS states are compromised.
        for position, os_state in enumerate(os_states):
            if os_state == "COMPROMISED":
                # The service index is irrelevant for an OS action.
                return encode(position + 1, "OS", "PATCHING", 0)

        # Next fix services: every COMPROMISED service is handled before any
        # OVERWHELMED one, matching the original priority order.
        for wanted_state in ("COMPROMISED", "OVERWHELMED"):
            for service_index, service in enumerate(services):
                for position, service_state in enumerate(service):
                    if service_state == wanted_state:
                        return encode(
                            position + 1, "SERVICE", "PATCHING", service_index
                        )

        # Finally, turn on any nodes that are off. The check must look at the
        # operating state being iterated (not a leftover OS state).
        for position, operating_state in enumerate(operating_states):
            if operating_state == "OFF":
                # The service index is irrelevant for an operating-state action.
                return encode(position + 1, "OPERATING", "ON", 0)

        # If there are no good actions, go with one that won't do any harm.
        return encode(1, "NONE", "ON", 0)
import PPOConfig from ray.rllib.algorithms.a2c import A2CConfig +from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentFramework, RedAgentIdentifier +from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite - _LOGGER = getLogger(__name__) def _env_creator(env_config): @@ -51,13 +49,13 @@ class RLlibAgent(AgentSessionABC): f"got {self._training_config.agent_framework}") _LOGGER.error(msg) raise ValueError(msg) - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + if self._training_config.agent_identifier == AgentIdentifier.PPO: self._agent_config_class = PPOConfig - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_config_class = A2CConfig else: - msg = ("Expected PPO or A2C red_agent_identifier, " - f"got {self._training_config.red_agent_identifier.value}") + msg = ("Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier.value}") _LOGGER.error(msg) raise ValueError(msg) self._agent_config: PPOConfig @@ -67,8 +65,8 @@ class RLlibAgent(AgentSessionABC): _LOGGER.debug( f"Created {self.__class__.__name__} using: " f"agent_framework={self._training_config.agent_framework}, " - f"red_agent_identifier=" - f"{self._training_config.red_agent_identifier}, " + f"agent_identifier=" + f"{self._training_config.agent_identifier}, " f"deep_learning_framework=" f"{self._training_config.deep_learning_framework}" ) @@ -117,7 +115,7 @@ class RLlibAgent(AgentSessionABC): train_batch_size=self._training_config.num_steps ) self._agent_config.framework( - framework=self._training_config.deep_learning_framework + framework="torch" ) self._agent_config.rollouts( diff --git 
a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 3cd2e50a..3748b57d 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -2,12 +2,12 @@ from typing import Optional import numpy as np from stable_baselines3 import PPO, A2C +from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import RedAgentIdentifier, AgentFramework +from primaite.common.enums import AgentIdentifier, AgentFramework from primaite.environment.primaite_env import Primaite -from stable_baselines3.ppo import MlpPolicy as PPOMlp _LOGGER = getLogger(__name__) @@ -24,13 +24,13 @@ class SB3Agent(AgentSessionABC): f"got {self._training_config.agent_framework}") _LOGGER.error(msg) raise ValueError(msg) - if self._training_config.red_agent_identifier == RedAgentIdentifier.PPO: + if self._training_config.agent_identifier == AgentIdentifier.PPO: self._agent_class = PPO - elif self._training_config.red_agent_identifier == RedAgentIdentifier.A2C: + elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_class = A2C else: - msg = ("Expected PPO or A2C red_agent_identifier, " - f"got {self._training_config.red_agent_identifier.value}") + msg = ("Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier.value}") _LOGGER.error(msg) raise ValueError(msg) @@ -40,8 +40,8 @@ class SB3Agent(AgentSessionABC): _LOGGER.debug( f"Created {self.__class__.__name__} using: " f"agent_framework={self._training_config.agent_framework}, " - f"red_agent_identifier=" - f"{self._training_config.red_agent_identifier}" + f"agent_identifier=" + f"{self._training_config.agent_identifier}" ) def _setup(self): @@ -56,7 +56,7 @@ class SB3Agent(AgentSessionABC): self._agent = self._agent_class( PPOMlp, self._env, - verbose=self._training_config.output_verbose_level, + verbose=self.output_verbose_level, n_steps=self._training_config.num_steps, 
tensorboard_log=self._tensorboard_log_path ) @@ -118,6 +118,7 @@ class SB3Agent(AgentSessionABC): action = np.int64(action) obs, rewards, done, info = self._env.step(action) + @classmethod def load(self): raise NotImplementedError diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py new file mode 100644 index 00000000..cf333b1e --- /dev/null +++ b/src/primaite/agents/simple.py @@ -0,0 +1,60 @@ +from primaite.agents.agent import HardCodedAgentSessionABC +from primaite.agents.utils import ( + get_new_action, + transform_action_acl_enum, + transform_action_node_enum, +) + + +class RandomAgent(HardCodedAgentSessionABC): + """ + A Random Agent. + + Get a completely random action from the action space. + """ + + def _calculate_action(self, obs): + return self._env.action_space.sample() + + +class DummyAgent(HardCodedAgentSessionABC): + """ + A Dummy Agent. + + All action spaces setup so dummy action is always 0 regardless of action + type used. + """ + + def _calculate_action(self, obs): + return 0 + + +class DoNothingACLAgent(HardCodedAgentSessionABC): + """ + A do nothing ACL agent. + + A valid ACL action that has no effect; does nothing. + """ + + def _calculate_action(self, obs): + nothing_action = ["NONE", "ALLOW", "ANY", "ANY", "ANY", "ANY"] + nothing_action = transform_action_acl_enum(nothing_action) + nothing_action = get_new_action(nothing_action, self._env.action_dict) + + return nothing_action + + +class DoNothingNodeAgent(HardCodedAgentSessionABC): + """ + A do nothing Node agent. + + A valid Node action that has no effect; does nothing. 
def transform_change_obs_readable(obs):
    """Transform an observation into a readable list per observation property.

    example:

        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']]

    :param obs: 2-D array; column 0 is the ID, column 1 the hardware state,
        column 2 the OS state and every further column a service state.
    :return: ``[ids, operating_states, os_states, service_1_states, ...]``.
    """
    ids = list(obs[:, 0])
    operating_states = [HardwareState(i).name for i in obs[:, 1]]
    os_states = [SoftwareState(i).name for i in obs[:, 2]]
    new_obs = [ids, operating_states, os_states]

    for service in range(3, obs.shape[1]):
        # Links bit/s don't have a service state, so values above 4 are kept raw.
        service_states = [
            SoftwareState(i).name if i <= 4 else i for i in obs[:, service]
        ]
        new_obs.append(service_states)

    return new_obs


def transform_obs_readable(obs):
    """Transform an observation into one readable row per observed item.

    example:
        np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']]
    """
    changed_obs = transform_change_obs_readable(obs)
    # Transpose the per-property lists into per-row lists.
    return [list(row) for row in zip(*changed_obs)]


def convert_to_new_obs(obs, num_nodes=10):
    """Convert original gym Box observation space to new MultiDiscrete observation space.

    Removes the ID column, drops every row after the first ``num_nodes`` rows
    (the links) and flattens the remainder.
    """
    return obs[:num_nodes, 1:].flatten()


def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1):
    """Convert a flat observation back to the old, tabular observation.

    Link rows are filled with 0's except for the last column, as no other
    link information is included in the new observation space.

    example:
        obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1])

        new_obs = array([[ 1,  1,  1,  1],
                         [ 2,  1,  1,  1],
                         [ 3,  1,  1,  1],
                         ...
                         [20,  0,  0,  0]])

    :param obs: 1-D array: node values followed by ``num_links`` link values.
    :param num_nodes: Number of nodes encoded in ``obs``.
    :param num_links: Number of trailing link values in ``obs``.
    :param num_services: Number of service columns per node.
    :return: 2-D ``int64`` array with an ID column added back.
    :raises ValueError: If the node part of ``obs`` cannot be reshaped to
        ``(num_nodes, num_services + 2)``.
    """
    # Use an explicit split index: negative slicing with num_links == 0 would
    # give obs[:-0] (empty) and obs[-0:] (everything), which is wrong.
    split = max(len(obs) - num_links, 0)

    # Convert back to the more readable, original format.
    reshaped_nodes = obs[:split].reshape(num_nodes, num_services + 2)

    # Add empty links back and add the node ID back.
    new_obs = np.zeros(
        [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1],
        dtype=np.int64,
    )
    new_obs[:, 0] = range(1, num_nodes + num_links + 1)  # adding ID back
    new_obs[:num_nodes, 1:] = reshaped_nodes  # put values back in

    # Links are added to the last protocol/service slot, but they are not
    # specific to that service.
    new_obs[num_nodes:, -1] = obs[split:]

    return new_obs


def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1):
    """Return a string describing the change between two observations.

    Both observations are first converted with ``convert_to_old_obs``; every
    row that differs contributes one description line based on the values
    in ``obs2``.

    example:
        obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]])
        obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]])
        output = 'ID 1: SERVICE 2 set to GOOD'

    :return: An empty string when nothing changed, otherwise the changes,
        each prefixed by a newline and a space.
    """
    obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services)
    obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services)
    list_of_changes = []
    for n, row in enumerate(obs1 - obs2):
        if row.any():
            # Keep the new value where something changed, -1 elsewhere.
            relevant_changes = np.where(row != 0, obs2[n], -1)
            relevant_changes[0] = obs2[n, 0]  # ID is always relevant
            is_link = relevant_changes[0] > num_nodes
            list_of_changes.append(
                _describe_obs_change_helper(relevant_changes, is_link)
            )

    change_string = "\n ".join(list_of_changes)
    if list_of_changes:
        change_string = "\n " + change_string
    return change_string


def _describe_obs_change_helper(obs_change, is_link):
    """Describe what has changed in a single observation row.

    example:
        [ 1 -1 -1 -1  1] -> "ID 1: Service 1 changed to GOOD"

    Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.'

    :param obs_change: Row where index 0 is the ID and every other entry is
        the new value, or -1 where nothing changed.
    :param is_link: Whether the row describes a link rather than a node.
    """
    # Indexes where a change has occurred, not including the 0th index.
    index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1]
    # Node pol types; indexes >= 3 are services.
    # NOTE: local names must not reuse the name ``NodePOLType``; binding it
    # anywhere in this function would hide the enum for the whole function.
    pol_type_names = [
        NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3)
        for i in index_changed
    ]
    # Account for hardware states, software states and links.
    states = [
        LinkStatus(obs_change[i]).name
        if is_link
        else HardwareState(obs_change[i]).name
        if i == 1
        else SoftwareState(obs_change[i]).name
        for i in index_changed
    ]

    if not is_link:
        desc = f"ID {obs_change[0]}:"
        for pol_type_name, state in zip(pol_type_names, states):
            desc = desc + " " + pol_type_name + " changed to " + state + "."
    else:
        desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}."

    return desc


def transform_action_node_enum(action):
    """Convert a node action from readable string format to enumerated format.

    example:
        [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0]
    """
    action_node_id = action[0]
    action_node_property = NodePOLType[action[1]].value

    if action[1] == "OPERATING":
        property_action = NodeHardwareAction[action[2]].value
    elif action[1] == "OS" or action[1] == "SERVICE":
        property_action = NodeSoftwareAction[action[2]].value
    else:
        # Any other property has no action to perform.
        property_action = 0

    action_service_index = action[3]

    return [action_node_id, action_node_property, property_action, action_service_index]


def transform_action_node_readable(action):
    """Convert a node action from enumerated format to readable format.

    example:
        [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
    """
    action_node_property = NodePOLType(action[1]).name

    if action_node_property == "OPERATING":
        property_action = NodeHardwareAction(action[2]).name
    elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1:
        property_action = NodeSoftwareAction(action[2]).name
    else:
        property_action = "NONE"

    return [action[0], action_node_property, property_action, action[3]]


def node_action_description(action):
    """Generate a string describing a node-based action.

    :param action: Node action in enumerated or readable format.
    :return: The description, or an empty string for an action that does nothing.
    """
    # Accept every NumPy integer width, not only int64.
    if isinstance(action[1], (int, np.integer)):
        # transform action to readable format
        action = transform_action_node_readable(action)

    node_id = action[0]
    node_property = action[1]
    property_action = action[2]
    service_id = action[3]

    if property_action == "NONE":
        return ""
    if node_property == "OPERATING" or node_property == "OS":
        description = f"NODE {node_id}, {node_property}, SET TO {property_action}"
    elif node_property == "SERVICE":
        description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}"
    else:
        return ""

    return description
def transform_action_acl_readable(action):
    """Transform an ACL action to a more readable format.

    example:
        [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
    """
    action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
    action_permissions = {0: "DENY", 1: "ALLOW"}

    action_decision = action_decisions[action[0]]
    action_permission = action_permissions[action[1]]

    # For IPs, ports and protocols, 0 means ANY, otherwise it's just an index.
    targets = ["ANY" if val == 0 else val for val in list(action[2:6])]

    return [action_decision, action_permission] + targets


def transform_action_acl_enum(action):
    """Convert an ACL action from readable string format to enumerated format.

    :return: A NumPy array holding the enumerated action.
    """
    action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2}
    action_permissions = {"DENY": 0, "ALLOW": 1}

    action_decision = action_decisions[action[0]]
    action_permission = action_permissions[action[1]]

    # For IPs, ports and protocols, ANY has value 0, otherwise it's just an index.
    new_action = [action_decision, action_permission] + list(action[2:6])
    for n, val in enumerate(list(action[2:6])):
        if val == "ANY":
            new_action[n + 2] = 0

    return np.array(new_action)


def acl_action_description(action):
    """Generate a string describing an ACL-based action.

    :param action: ACL action in enumerated or readable format.
    """
    # Accept every NumPy integer width, not only int64.
    if isinstance(action[0], (int, np.integer)):
        # transform action to readable format
        action = transform_action_acl_readable(action)
    if action[0] == "NONE":
        description = "NO ACL RULE APPLIED"
    else:
        description = (
            f"{action[0]} RULE: {action[1]} traffic from IP {action[2]} to IP {action[3]},"
            f" for protocol/service index {action[4]} on port index {action[5]}"
        )

    return description


def get_node_of_ip(ip, node_dict):
    """Get the node ID of an IP address.

    :param ip: The IP address to look for.
    :param node_dict: Dictionary of nodes where the key is the ID and the
        value is the node (can be obtained from env.nodes).
    :return: The key of the first node whose ``ip_address`` equals ``ip``,
        or ``None`` when no node matches.
    """
    for node_key, node_value in node_dict.items():
        if node_value.ip_address == ip:
            return node_key
    return None


def is_valid_node_action(action):
    """Is the node action an actual valid action.

    Only uses information about the action to determine if the action has an effect.

    Does NOT consider:
        - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
        - Node already being in that state (turning an ON node ON)
    """
    action_r = transform_action_node_readable(action)

    node_property = action_r[1]
    node_action = action_r[2]

    if node_property == "NONE":
        return False
    if node_action == "NONE":
        return False
    if node_property == "OPERATING" and node_action == "PATCHING":
        # Operating State cannot PATCH
        return False
    if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
        # Software States can only do Nothing or Patch
        return False
    return True


def is_valid_acl_action(action):
    """Is the ACL action an actual valid action.

    Only uses information about the action to determine if the action has an effect.

    Does NOT consider:
        - Trying to create identical rules
        - Trying to create a rule which is a subset of another rule (caused by "ANY")
    """
    action_r = transform_action_acl_readable(action)

    action_decision = action_r[0]
    action_permission = action_r[1]
    action_source_id = action_r[2]
    action_destination_id = action_r[3]

    if action_decision == "NONE":
        return False
    if action_source_id == action_destination_id and action_source_id != "ANY":
        # ACL rule towards itself
        return False
    if action_permission == "DENY":
        # DENY is unnecessary, we can create and delete allow rules instead.
        # No allow rule = blocked/DENY by default. ALLOW overrides existing DENY.
        return False

    return True


def is_valid_acl_action_extra(action):
    """Harsher version of valid ACL actions: also rejects ANY protocol or port."""
    if not is_valid_acl_action(action):
        return False

    action_r = transform_action_acl_readable(action)
    action_protocol = action_r[4]
    action_port = action_r[5]

    # Don't allow protocols or ports to be ANY.
    # In the future we might want to do the opposite, and only have the ANY
    # option for ports and services.
    if action_protocol == "ANY":
        return False
    if action_port == "ANY":
        return False

    return True


def get_new_action(old_action, action_dict):
    """Get the new action (e.g. 32) from an old action e.g. [1,1,1,0].

    old_action can be either node or acl action type.

    :return: The key of the first entry in ``action_dict`` equal to
        ``old_action``, or 0 when there is none.
    """
    wanted = list(old_action)
    for key, val in action_dict.items():
        if list(val) == wanted:
            return key
    # Not all possible actions are included in the dict, only valid actions
    # are. If the action is not in the dict it's an invalid action, so return 0.
    return 0


def get_action_description(action, action_dict):
    """Get a string describing/explaining what an action is doing in words.

    :param action: Key of the action in ``action_dict``.
    :param action_dict: Mapping of action key to node or ACL action array.
    """
    action_array = action_dict[action]
    if len(action_array) == 4:
        # node actions have length 4
        action_description = node_action_description(action_array)
    elif len(action_array) == 6:
        # acl actions have length 6
        action_description = acl_action_description(action_array)
    else:
        # Should never happen
        action_description = "Unrecognised action"

    return action_description
+96,8 @@ class VerboseLevel(IntEnum): class AgentFramework(Enum): - NONE = 0 + """The agent algorithm framework/package.""" + CUSTOM = 0 "Custom Agent" SB3 = 1 "Stable Baselines3" @@ -103,7 +106,7 @@ class AgentFramework(Enum): class DeepLearningFramework(Enum): - """The deep learning framework enumeration.""" + """The deep learning framework.""" TF = "tf" "Tensorflow" TF2 = "tf2" @@ -112,15 +115,28 @@ class DeepLearningFramework(Enum): "PyTorch" -class RedAgentIdentifier(Enum): +class AgentIdentifier(Enum): + """The Red Agent algo/class.""" A2C = 1 "Advantage Actor Critic" PPO = 2 "Proximal Policy Optimization" HARDCODED = 3 - "Custom Agent" - RANDOM = 4 - "Custom Agent" + "The Hardcoded agents" + DO_NOTHING = 4 + "The DoNothing agents" + RANDOM = 5 + "The RandomAgent" + DUMMY = 6 + "The DummyAgent" + + +class HardCodedAgentView(Enum): + """The view the deterministic hard-coded agent has of the environment.""" + BASIC = 1 + "The current observation space only" + FULL = 2 + "Full environment view with actions taken and reward feedback" class ActionType(Enum): diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d7b4db98..2cc29c55 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -1,32 +1,41 @@ -# Main Config File +# Training Config File -# Sets which agent algorithm framework will be used: +# Sets which agent algorithm framework will be used. # Options are: # "SB3" (Stable Baselines3) # "RLLIB" (Ray RLlib) -# "NONE" (Custom Agent) +# "CUSTOM" (Custom Agent) agent_framework: RLLIB -# Sets which deep learning framework will be used. Default is TF (Tensorflow). +# Sets which deep learning framework will be used (by RLlib ONLY). +# Default is TF (Tensorflow). 
# Options are: # "TF" (Tensorflow) # TF2 (Tensorflow 2.X) # TORCH (PyTorch) deep_learning_framework: TORCH -# Sets which Red Agent algo/class will be used: +# Sets which Agent class will be used. # Options are: -# "A2C" (Advantage Actor Critic) -# "PPO" (Proximal Policy Optimization) -# "HARDCODED" (Custom Agent) -# "RANDOM" (Random Action) -red_agent_identifier: PPO +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: PPO + +# Sets what view of the environment the deterministic hardcoded agent has. The default is BASIC. +# Options are: +# "BASIC" (The current observation space only) +# "FULL" (Full environment view with actions taken and reward feedback) +hard_coded_agent_view: FULL # Sets How the Action Space is defined: # "NODE" # "ACL" # "ANY" node and acl actions -action_type: NODE +action_type: ACL # Number of episodes to run per session num_episodes: 10 diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4695f2f5..f8adae25 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -8,8 +8,8 @@ from typing import Any, Dict, Final, Union, Optional import yaml from primaite import USERS_CONFIG_DIR, getLogger -from primaite.common.enums import DeepLearningFramework -from primaite.common.enums import ActionType, RedAgentIdentifier, \ +from primaite.common.enums import DeepLearningFramework, HardCodedAgentView +from primaite.common.enums import ActionType, AgentIdentifier, \ AgentFramework, SessionType, OutputVerboseLevel _LOGGER = getLogger(__name__) @@ -42,8 +42,11 @@ class TrainingConfig: 
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF "The DeepLearningFramework" - red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO - "The RedAgentIdentifier" + agent_identifier: AgentIdentifier = AgentIdentifier.PPO + "The AgentIdentifier" + + hard_coded_agent_view: HardCodedAgentView = HardCodedAgentView.FULL + "The view the deterministic hard-coded agent has of the environment" action_type: ActionType = ActionType.ANY "The ActionType to use" @@ -176,10 +179,11 @@ class TrainingConfig: field_enum_map = { "agent_framework": AgentFramework, "deep_learning_framework": DeepLearningFramework, - "red_agent_identifier": RedAgentIdentifier, + "agent_identifier": AgentIdentifier, "action_type": ActionType, "session_type": SessionType, - "output_verbose_level": OutputVerboseLevel + "output_verbose_level": OutputVerboseLevel, + "hard_coded_agent_view": HardCodedAgentView } for field, enum_class in field_enum_map.items(): @@ -197,12 +201,13 @@ class TrainingConfig: """ data = self.__dict__ if json_serializable: - data["agent_framework"] = self.agent_framework.value - data["deep_learning_framework"] = self.deep_learning_framework.value - data["red_agent_identifier"] = self.red_agent_identifier.value - data["action_type"] = self.action_type.value - data["output_verbose_level"] = self.output_verbose_level.value - data["session_type"] = self.session_type.value + data["agent_framework"] = self.agent_framework.name + data["deep_learning_framework"] = self.deep_learning_framework.name + data["agent_identifier"] = self.agent_identifier.name + data["action_type"] = self.action_type.name + data["output_verbose_level"] = self.output_verbose_level.name + data["session_type"] = self.session_type.name + data["hard_coded_agent_view"] = self.hard_coded_agent_view.name return data @@ -255,7 +260,7 @@ def load( def convert_legacy_training_config_dict( legacy_config_dict: Dict[str, Any], agent_framework: AgentFramework = AgentFramework.SB3, - 
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO, + agent_identifier: AgentIdentifier = AgentIdentifier.PPO, action_type: ActionType = ActionType.ANY, num_steps: int = 256, output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO @@ -266,8 +271,8 @@ def convert_legacy_training_config_dict( :param legacy_config_dict: A legacy training config dict. :param agent_framework: The agent framework to use as legacy training configs don't have agent_framework values. - :param red_agent_identifier: The red agent identifier to use as legacy - training configs don't have red_agent_identifier values. + :param agent_identifier: The red agent identifier to use as legacy + training configs don't have agent_identifier values. :param action_type: The action space type to set as legacy training configs don't have action_type values. :param num_steps: The number of steps to set as legacy training configs @@ -278,7 +283,7 @@ def convert_legacy_training_config_dict( """ config_dict = { "agent_framework": agent_framework.name, - "red_agent_identifier": red_agent_identifier.name, + "agent_identifier": agent_identifier.name, "action_type": action_type.name, "num_steps": num_steps, "output_verbose_level": output_verbose_level diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 0876f070..502069ec 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -97,7 +97,7 @@ class Primaite(Env): self.transaction_list = transaction_list # The agent in use - self.agent_identifier = self.training_config.red_agent_identifier + self.agent_identifier = self.training_config.agent_identifier # Create a dictionary to hold all the nodes self.nodes: Dict[str, NodeUnion] = {} diff --git a/src/primaite/main.py b/src/primaite/main.py index 34134ba2..100248dd 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -1,137 +1,15 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. 
Shared in confidence. -""" -The main PrimAITE session runner module. - -TODO: This will eventually be refactored out into a proper Session class. -TODO: The passing about of session_path and timestamp_str is temporary and - will be cleaned up once we move to a proper Session class. -""" +"""The main PrimAITE session runner module.""" import argparse -import json -import time -from datetime import datetime from pathlib import Path -from typing import Final, Union -from uuid import uuid4 +from typing import Union -from stable_baselines3 import A2C, PPO -from stable_baselines3.common.evaluation import evaluate_policy -from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.ppo import MlpPolicy as PPOMlp - -from primaite import SESSIONS_DIR, getLogger -from primaite.config.training_config import TrainingConfig -from primaite.environment.primaite_env import Primaite +from primaite import getLogger from primaite.primaite_session import PrimaiteSession -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file _LOGGER = getLogger(__name__) -def run_generic(env: Primaite, config_values: TrainingConfig): - """ - Run against a generic agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. 
- """ - for episode in range(0, config_values.num_episodes): - env.reset() - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - action = env.action_space.sample() - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - - # Reset the environment at the end of the episode - - env.close() - - -def run_stable_baselines3_ppo( - env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str -): - """ - Run against a stable_baselines3 PPO agent. - - :param env: An instance of - :class:`~primaite.environment.primaite_env.Primaite`. - :param config_values: An instance of - :class:`~primaite.config.training_config.TrainingConfig`. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. 
- """ - if config_values.load_agent: - try: - agent = PPO.load( - config_values.agent_load_file, - env, - verbose=0, - n_steps=config_values.num_steps, - ) - except Exception: - print( - "ERROR: Could not load agent at location: " - + config_values.agent_load_file - ) - _LOGGER.error("Could not load agent") - _LOGGER.error("Exception occured", exc_info=True) - else: - agent = PPO(PPOMlp, env, verbose=0, n_steps=config_values.num_steps) - - if config_values.session_type == "TRAINING": - # We're in a training session - print("Starting training session...") - _LOGGER.debug("Starting training session...") - for episode in range(config_values.num_episodes): - agent.learn(total_timesteps=config_values.num_steps) - _save_agent(agent, session_path, timestamp_str) - else: - # Default to being in an evaluation session - print("Starting evaluation session...") - _LOGGER.debug("Starting evaluation session...") - evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes) - - env.close() - - - - -def _save_agent(agent: OnPolicyAlgorithm, session_path: Path, timestamp_str: str): - """ - Persist an agent. - - Only works for stable baselines3 agents at present. - - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. - """ - if not isinstance(agent, OnPolicyAlgorithm): - msg = f"Can only save {OnPolicyAlgorithm} agents, got {type(agent)}." - _LOGGER.error(msg) - else: - filepath = session_path / f"agent_saved_{timestamp_str}" - agent.save(filepath) - _LOGGER.debug(f"Trained agent saved as: {filepath}") - - - - def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): """Run the PrimAITE Session. 
diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index a4148d12..70a18a4b 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -8,9 +8,13 @@ from uuid import uuid4 from primaite import getLogger, SESSIONS_DIR from primaite.agents.agent import AgentSessionABC +from primaite.agents.hardcoded_acl import HardCodedACLAgent +from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent -from primaite.common.enums import AgentFramework, RedAgentIdentifier, \ +from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, \ + RandomAgent, DummyAgent +from primaite.common.enums import AgentFramework, AgentIdentifier, \ ActionType, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig @@ -68,31 +72,66 @@ class PrimaiteSession: self.learn() def setup(self): - if self._training_config.agent_framework == AgentFramework.NONE: - if self._training_config.red_agent_identifier == RedAgentIdentifier.RANDOM: - # Stochastic Random Agent - raise NotImplementedError - - elif self._training_config.red_agent_identifier == RedAgentIdentifier.HARDCODED: + if self._training_config.agent_framework == AgentFramework.CUSTOM: + if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space - raise NotImplementedError + self._agent_session = HardCodedNodeAgent( + self._training_config_path, + self._lay_down_config_path + ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - raise NotImplementedError + self._agent_session = HardCodedACLAgent( + self._training_config_path, + self._lay_down_config_path + ) elif self._training_config.action_type == ActionType.ANY: # Deterministic 
Hardcoded Agent with ANY Action Space raise NotImplementedError else: - # Invalid RedAgentIdentifier ActionType combo - pass + # Invalid AgentIdentifier ActionType combo + raise ValueError + + elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: + if self._training_config.action_type == ActionType.NODE: + self._agent_session = DoNothingNodeAgent( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.action_type == ActionType.ACL: + # Deterministic Hardcoded Agent with ACL Action Space + self._agent_session = DoNothingACLAgent( + self._training_config_path, + self._lay_down_config_path + ) + + elif self._training_config.action_type == ActionType.ANY: + # Deterministic Hardcoded Agent with ANY Action Space + raise NotImplementedError + + else: + # Invalid AgentIdentifier ActionType combo + raise ValueError + + elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: + self._agent_session = RandomAgent( + self._training_config_path, + self._lay_down_config_path + ) + elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: + self._agent_session = DummyAgent( + self._training_config_path, + self._lay_down_config_path + ) else: - # Invalid AgentFramework RedAgentIdentifier combo - pass + # Invalid AgentFramework AgentIdentifier combo + raise ValueError elif self._training_config.agent_framework == AgentFramework.SB3: # Stable Baselines3 Agent @@ -110,7 +149,7 @@ class PrimaiteSession: else: # Invalid AgentFramework - pass + raise ValueError def learn( self, diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy/new_training_config.yaml index 44897bfa..9fdf9a05 100644 --- a/tests/config/legacy/new_training_config.yaml +++ b/tests/config/legacy/new_training_config.yaml @@ -13,7 +13,7 @@ agent_framework: RLLIB # "A2C" (Advantage Actor Critic) # "HARDCODED" (Custom Agent) # "RANDOM" (Random Action) -red_agent_identifier: PPO +agent_identifier: PPO # Sets How the 
Action Space is defined: # "NODE" From 5a6fdf58d49596f96761e11ba25afff1523f07e3 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Tue, 20 Jun 2023 22:29:46 +0100 Subject: [PATCH 07/43] #917 - Got things working'ish --- src/primaite/agents/rllib.py | 16 ++++++++++++---- src/primaite/agents/sb3.py | 2 +- src/primaite/environment/primaite_env.py | 3 +++ 3 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 7d0cde60..b4b0ec56 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -8,10 +8,11 @@ from ray.rllib.algorithms.a2c import A2CConfig from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env - +import tensorflow as tf from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier +from primaite.common.enums import AgentFramework, AgentIdentifier, \ + DeepLearningFramework from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) @@ -115,7 +116,7 @@ class RLlibAgent(AgentSessionABC): train_batch_size=self._training_config.num_steps ) self._agent_config.framework( - framework="torch" + framework="tf" ) self._agent_config.rollouts( @@ -127,6 +128,7 @@ class RLlibAgent(AgentSessionABC): logger_creator=_custom_log_creator(self.session_path) ) + def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] @@ -154,8 +156,14 @@ class RLlibAgent(AgentSessionABC): for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() - self._agent.stop() + if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH: + policy = self._agent.get_policy() + tf.compat.v1.summary.FileWriter( + self.session_path / "ray_results", + policy.get_session().graph + ) 
super().learn() + self._agent.stop() def evaluate( self, diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 3748b57d..073eb2fe 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -86,10 +86,10 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() + self._env.close() super().learn() diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 502069ec..4a958fa6 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,6 +3,7 @@ import copy import csv import logging +import time from datetime import datetime from pathlib import Path from typing import Dict, Tuple, Union, Final @@ -301,6 +302,8 @@ class Primaite(Env): done: Indicates episode is complete if True step_info: Additional information relating to this step """ + # Introduce a delay between steps + time.sleep(self.training_config.time_delay / 1000) if self.step_count == 0: print(f"Episode: {str(self.episode_count)}") From 7f1c4ce036eea72d8df2febc11357663915240c8 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Thu, 22 Jun 2023 14:10:38 +0100 Subject: [PATCH 08/43] #917 - Updated main config --- .../_package_data/training/training_config_main.yaml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 2cc29c55..0f99a501 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -5,7 +5,7 @@ # "SB3" (Stable Baselines3) # "RLLIB" (Ray RLlib) # "CUSTOM" (Custom Agent) -agent_framework: RLLIB +agent_framework: SB3 # Sets which deep learning framework will be used (by 
RLlib ONLY). # Default is TF (Tensorflow). @@ -35,20 +35,20 @@ hard_coded_agent_view: FULL # "NODE" # "ACL" # "ANY" node and acl actions -action_type: ACL +action_type: ANY # Number of episodes to run per session -num_episodes: 10 +num_episodes: 100 # Number of time_steps per episode num_steps: 256 # Sets how often the agent will save a checkpoint (every n time episodes). # Set to 0 if no checkpoints are required. Default is 10 -checkpoint_every_n_episodes: 5 +checkpoint_every_n_episodes: 100 # Time delay between steps (for generic agents) -time_delay: 10 +time_delay: 3 # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING From e0f3d61f6511181e861cfe4bc27d46cb18c0fba7 Mon Sep 17 00:00:00 2001 From: Brian Kanyora Date: Thu, 22 Jun 2023 15:34:13 +0100 Subject: [PATCH 09/43] feature\1522: Create random red agent behaviour. --- src/primaite/config/training_config.py | 17 +- src/primaite/environment/primaite_env.py | 173 ++++++++++++++++-- .../nodes/node_state_instruction_red.py | 17 ++ tests/config/random_agent_main_config.yaml | 96 ++++++++++ tests/test_red_random_agent_behaviour.py | 74 ++++++++ 5 files changed, 356 insertions(+), 21 deletions(-) create mode 100644 tests/config/random_agent_main_config.yaml create mode 100644 tests/test_red_random_agent_behaviour.py diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 4af36abe..6e88e7cb 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,7 +1,7 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from dataclasses import dataclass, field from pathlib import Path -from typing import Any, Dict, Final, Union, Optional +from typing import Any, Dict, Final, Optional, Union import yaml @@ -21,6 +21,9 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." 
+ red_agent_identifier: str = "RANDOM" + "Creates Random Red Agent Attacks" + action_type: ActionType = ActionType.ANY "The ActionType to use." @@ -167,8 +170,7 @@ def main_training_config_path() -> Path: return path -def load(file_path: Union[str, Path], - legacy_file: bool = False) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. @@ -213,9 +215,7 @@ def load(file_path: Union[str, Path], def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - num_steps: int = 256, - action_type: str = "ANY" + legacy_config_dict: Dict[str, Any], num_steps: int = 256, action_type: str = "ANY" ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -227,10 +227,7 @@ def convert_legacy_training_config_dict( don't have action_type values. :return: The converted training config dict. """ - config_dict = { - "num_steps": num_steps, - "action_type": action_type - } + config_dict = {"num_steps": num_steps, "action_type": action_type} for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index da235971..9161fa43 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -14,8 +14,7 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,8 +23,9 @@ from primaite.common.enums import ( NodePOLInitiator, NodePOLType, NodeType, + ObservationType, Priority, - SoftwareState, ObservationType, + SoftwareState, 
) from primaite.common.service import Service from primaite.config import training_config @@ -35,15 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction _LOGGER = logging.getLogger(__name__) @@ -177,7 +175,6 @@ class Primaite(Env): # It will be initialised later. 
self.obs_handler: ObservationsHandler - # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -238,7 +235,9 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}") + _LOGGER.info( + f"Invalid action type selected: {self.training_config.action_type}" + ) # Set up a csv to store the results of the training try: header = ["Episode", "Average Reward"] @@ -275,6 +274,10 @@ class Primaite(Env): # Does this for both live and reference nodes self.reset_environment() + # Create a random red agent to use for this episode + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() + # Reset counters and totals self.total_reward = 0 self.step_count = 0 @@ -379,7 +382,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - #print(f" Step {self.step_count} Reward: {str(reward)}") + print(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -1033,7 +1036,6 @@ class Primaite(Env): """ self.observation_type = ObservationType[observation_info["type"]] - def get_action_info(self, action_info): """ Extracts action_info. 
@@ -1216,3 +1218,152 @@ class Primaite(Env): # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict + + def create_random_red_agent(self): + """Decide on random red agent for the episode to be called in env.reset().""" + + # Reset the current red iers and red node pol + self.red_iers = {} + self.red_node_pol = {} + + # Decide how many nodes become compromised + node_list = list(self.nodes.values()) + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len( + computers + ) # only computers can become compromised + # random select between 1 and max_num_nodes_compromised + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + + # Decide which of the nodes to compromise + nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + + # For each of the nodes to be compromised decide which step they become compromised + max_step_compromised = ( + self.episode_steps // 2 + ) # always compromise in first half of episode + + # Bandwidth for all links + bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + servers = [node for node in node_list if node.node_type == NodeType.SERVER] + + for n, node in enumerate(nodes_to_be_compromised): + # 1: Use Node PoL to set node to compromised + + _id = str(1000 + n) # doesn't really matter, make sure it doesn't duplicate + _start_step = np.random.randint( + 2, max_step_compromised + 1 + ) # step compromised + _end_step = _start_step # Become compromised on 1 step + _target_node_id = node.node_id + _pol_initiator = "DIRECT" + _pol_type = NodePOLType["SERVICE"] # All computers are service nodes + pol_service_name = np.random.choice( + list(node.get_services().keys()) + ) # Random service may wish to change this, currently always TCP) + pol_protocol = pol_protocol + _pol_state = SoftwareState.COMPROMISED + is_entry_node = True # Assumes all 
computers in network are entry nodes + _pol_source_node_id = _pol_source_node_id + _pol_source_node_service = _pol_source_node_service + _pol_source_node_service_state = _pol_source_node_service_state + red_pol = NodeStateInstructionRed( + _id, + _start_step, + _end_step, + _target_node_id, + _pol_initiator, + _pol_type, + pol_protocol, + _pol_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + + self.red_node_pol[_id] = red_pol + + # 2: Launch the attack from compromised node - set the IER + + ier_id = str(2000 + n) + # Launch the attack after node is compromised, and not right at the end of the episode + ier_start_step = np.random.randint( + _start_step + 2, int(self.episode_steps * 0.8) + ) + ier_end_step = self.episode_steps + ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a random link bandwith + ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + bandwidths + ) + ier_protocol = pol_service_name # Same protocol as compromised node + ier_service = node.get_services()[ + pol_service_name + ] # same service as defined in the pol + ier_port = ier_service.get_port() + ier_mission_criticality = ( + 0 # Red IER will never be important to green agent success + ) + # We choose a node to attack based on the first that applies: + # a. Green IERs, select dest node of the red ier based on dest node of green IER + # b. Attack a random server that doesn't have a DENY acl rule in default config + # c. 
Attack a random server + possible_ier_destinations = [ + ier.get_dest_node_id() + for ier in list(self.green_iers.values()) + if ier.get_source_node_id() == node.get_id() + ] + if len(possible_ier_destinations) < 1: + for server in servers: + if not self.acl.is_blocked( + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, + ): + possible_ier_destinations.append(server.node_id) + if len(possible_ier_destinations) < 1: + # If still none found choose from all servers + possible_ier_destinations = [server.node_id for server in servers] + ier_dest = np.random.choice(possible_ier_destinations) + self.red_iers[ier_id] = IER( + ier_id, + ier_start_step, + ier_end_step, + ier_load, + ier_protocol, + ier_port, + ier_source_node_id, + ier_dest, + ier_mission_criticality, + ) + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # TODO remove duplicate red pol for same targetted service - must take into account start step + + o_pol_id = str(3000 + n) + o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + o_pol_end_step = ( + self.episode_steps + ) # Can become compromised at any timestep after start + o_pol_node_id = ier_dest # Node effected is the one targetted by the IER + o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + o_pol_service_name = ( + ier_protocol # Same protocol/service as the IER uses to attack + ) + o_pol_new_state = SoftwareState["OVERWHELMED"] + o_pol_entry_node = False # Assumes servers are not entry nodes + o_red_pol = NodeStateInstructionRed( + _id, + _start_step, + _end_step, + _target_node_id, + _pol_initiator, + _pol_type, + pol_protocol, + _pol_state, + _pol_source_node_id, + _pol_source_node_service, + _pol_source_node_service_state, + ) + self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..9ae917e9 100644 --- 
a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -137,3 +137,20 @@ class NodeStateInstructionRed(object): The source node service state """ return self.source_node_service_state + + def __repr__(self): + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"start_step={self.start_step}, " + f"end_step={self.end_step}, " + f"target_node_id={self.target_node_id}, " + f"initiator={self.initiator}, " + f"pol_type={self.pol_type}, " + f"service_name={self.service_name}, " + f"state={self.state}, " + f"source_node_id={self.source_node_id}, " + f"source_node_service={self.source_node_service}, " + f"source_node_service_state={self.source_node_service_state}" + f")" + ) \ No newline at end of file diff --git a/tests/config/random_agent_main_config.yaml b/tests/config/random_agent_main_config.yaml new file mode 100644 index 00000000..d2d18bbc --- /dev/null +++ b/tests/config/random_agent_main_config.yaml @@ -0,0 +1,96 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: GENERIC +# +red_agent_identifier: RANDOM +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: ANY +# Number of episodes to run per session +num_episodes: 1 +# Number of time_steps per episode +num_steps: 5 +# Time delay between steps (for generic agents) +time_delay: 1 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1_000_000_000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 
+off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py new file 
mode 100644 index 00000000..a86e32c1 --- /dev/null +++ b/tests/test_red_random_agent_behaviour.py @@ -0,0 +1,74 @@ +from datetime import time, datetime + +from primaite.environment.primaite_env import Primaite +from tests import TEST_CONFIG_ROOT +from tests.conftest import _get_temp_session_path + + +def run_generic(env, config_values): + """Run against a generic agent.""" + # Reset the environment at the start of the episode + env.reset() + for episode in range(0, config_values.num_episodes): + for step in range(0, config_values.num_steps): + # Send the observation space to the agent to get an action + # TEMP - random action for now + # action = env.blue_agent_action(obs) + # action = env.action_space.sample() + action = 0 + + # Run the simulation step on the live environment + obs, reward, done, info = env.step(action) + + # Break if done is True + if done: + break + + # Introduce a delay between steps + time.sleep(config_values.time_delay / 1000) + + # Reset the environment at the end of the episode + env.reset() + + env.close() + + +def test_random_red_agent_behaviour(): + """ + Test that hardware state is penalised at each step. + + When the initial state is OFF compared to reference state which is ON. + """ + list_of_node_instructions = [] + for i in range(2): + + """Takes a config path and returns the created instance of Primaite.""" + session_timestamp: datetime = datetime.now() + session_path = _get_temp_session_path(session_timestamp) + + timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + env = Primaite( + training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + transaction_list=[], + session_path=session_path, + timestamp_str=timestamp_str, + ) + training_config = env.training_config + training_config.num_steps = env.episode_steps + + # TOOD: This needs t be refactored to happen outside. 
Should be part of + # a main Session class. + if training_config.agent_identifier == "GENERIC": + run_generic(env, training_config) + all_red_actions = env.red_node_pol + list_of_node_instructions.append(all_red_actions) + + # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) + print(list_of_node_instructions[0]["1"].get_start_step()) + print(list_of_node_instructions[0]["1"].get_end_step()) + print(list_of_node_instructions[0]["1"].get_target_node_id()) + print(list_of_node_instructions[1]["1"].get_start_step()) + print(list_of_node_instructions[1]["1"].get_end_step()) + print(list_of_node_instructions[1]["1"].get_target_node_id()) + assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) From 09412cb43d438ce8394bb24853092c46a28924bc Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 27 Jun 2023 12:27:57 +0100 Subject: [PATCH 10/43] 1555 - updated doc-string to make test understanding easier --- tests/test_reward.py | 26 +++++++++++++++++--------- 1 file changed, 17 insertions(+), 9 deletions(-) diff --git a/tests/test_reward.py b/tests/test_reward.py index c3fcdfc4..56e31ed5 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -16,17 +16,25 @@ def test_rewards_are_being_penalised_at_each_step_function(): ) """ - On different steps (of the 13 in total) these are the following rewards for config_6 which are activated: - File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3) - Hardware State: onShouldBeOff = -2 (between Steps 4 & 6) - Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9) - Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12) + The config 'one_node_states_on_off_lay_down_config.yaml' has 15 steps: + On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute. + For example, turning the nodes' file system state to CORRUPT from its original state GOOD. 
+ As a result these are the following rewards are activated: + File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 = 3) + Hardware State: off_should_be_on = -10 * 2 (on Steps 4 - 6) + Service State: compromised_should_be_good = -20 * 2 (on Steps 7 - 9) + Software State: compromised_should_be_good = -20 * 2 (on Steps 10 - 12) - Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26 - Step Count: 13 + The Pattern of Life (PoLs) last for 2 steps, so the agent is penalised twice. + + Note: This test run inherits conftest.py where the PrimAITE environment is ran and the blue agent is hard-coded + to do NOTHING on every step so we use Pattern of Lifes (PoLs) to change the nodes states and display that the agent + is being penalised on every step where the live network node differs from the network reference node. + + Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120 + Step Count: 15 For the 4 steps where this occurs the average reward is: - Average Reward: 2 (26 / 13) + Average Reward: -8 (-120 / 15) """ - print("average reward", env.average_reward) assert env.average_reward == -8.0 From b8a4ede83f89a3d6b2912af8da1f0eb4be3eeb6d Mon Sep 17 00:00:00 2001 From: SunilSamra Date: Tue, 27 Jun 2023 16:59:43 +0100 Subject: [PATCH 11/43] 1555 - added specific steps to doc string --- tests/test_reward.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_reward.py b/tests/test_reward.py index 56e31ed5..b8c92274 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -20,16 +20,17 @@ def test_rewards_are_being_penalised_at_each_step_function(): On different steps, the laydown config has Pattern of Life (PoLs) which change a state of the node's attribute. For example, turning the nodes' file system state to CORRUPT from its original state GOOD. 
As a result these are the following rewards are activated: - File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 = 3) - Hardware State: off_should_be_on = -10 * 2 (on Steps 4 - 6) - Service State: compromised_should_be_good = -20 * 2 (on Steps 7 - 9) - Software State: compromised_should_be_good = -20 * 2 (on Steps 10 - 12) + File System State: corrupt_should_be_good = -10 * 2 (on Steps 1 & 2) + Hardware State: off_should_be_on = -10 * 2 (on Steps 4 & 5) + Service State: compromised_should_be_good = -20 * 2 (on Steps 7 & 8) + Software State: compromised_should_be_good = -20 * 2 (on Steps 10 & 11) The Pattern of Life (PoLs) last for 2 steps, so the agent is penalised twice. - Note: This test run inherits conftest.py where the PrimAITE environment is ran and the blue agent is hard-coded - to do NOTHING on every step so we use Pattern of Lifes (PoLs) to change the nodes states and display that the agent - is being penalised on every step where the live network node differs from the network reference node. + Note: This test run inherits from conftest.py where the PrimAITE environment is ran and the blue agent is hard-coded + to do NOTHING on every step. + We use Pattern of Lifes (PoLs) to change the nodes states and display that the agent is being penalised on all steps + where the live network node differs from the network reference node. 
Total Reward: -10 + -10 + -10 + -10 + -20 + -20 + -20 + -20 = -120 Step Count: 15 From 9666b92caa45b96cc007ca5ecf9456e079e12bdc Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 28 Jun 2023 11:07:45 +0100 Subject: [PATCH 12/43] Attempt to add flat spaces --- scratch.py | 6 +++++ .../training/training_config_main.yaml | 9 +++++-- src/primaite/environment/observations.py | 24 +++++++++++++++---- 3 files changed, 33 insertions(+), 6 deletions(-) create mode 100644 scratch.py diff --git a/scratch.py b/scratch.py new file mode 100644 index 00000000..6bab60c1 --- /dev/null +++ b/scratch.py @@ -0,0 +1,6 @@ +from primaite.main import run + +run( + "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml", + "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml", +) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..a679400c 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -11,12 +11,17 @@ agent_identifier: STABLE_BASELINES3_A2C # "ACL" # "ANY" node and acl actions action_type: NODE +# observation space +observation_space: + # flatten: true + components: + - name: NODE_LINK_TABLE # Number of episodes to run per session -num_episodes: 10 +num_episodes: 1000 # Number of time_steps per episode num_steps: 256 # Time delay between steps (for generic agents) -time_delay: 10 +time_delay: 0 # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING # Determine whether to load an agent from file diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 9e71ef1b..e6eb533c 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -311,8 +311,13 @@ class 
ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] + + # need to keep track of the flattened and unflattened version of the space (if there is one) self.space: spaces.Space + self.unflattened_space: spaces.Space + self.current_observation: Union[Tuple[np.ndarray], np.ndarray] + self.flatten: bool = False def update_obs(self): """Fetch fresh information about the environment.""" @@ -324,9 +329,14 @@ class ObservationsHandler: # If there is only one component, don't use a tuple, just pass through that component's obs. if len(current_obs) == 1: self.current_observation = current_obs[0] + # If there are many compoenents, the space may need to be flattened else: - self.current_observation = tuple(current_obs) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. + if self.flatten: + self.current_observation = spaces.flatten( + self.unflattened_space, tuple(current_obs) + ) + else: + self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -357,8 +367,11 @@ class ObservationsHandler: if len(component_spaces) == 1: self.space = component_spaces[0] else: - self.space = spaces.Tuple(component_spaces) - # TODO: We may need to add ability to flatten the space as not all agents support tuple spaces. 
+ self.unflattened_space = spaces.Tuple(component_spaces) + if self.flatten: + self.space = spaces.flatten_space(spaces.Tuple(component_spaces)) + else: + self.space = self.unflattened_space @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -388,6 +401,9 @@ class ObservationsHandler: # Instantiate the handler handler = cls() + if obs_space_config.get("flatten"): + handler.flatten = True + for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] From 74821920465a61550728475b1fba3bb6ca32ae55 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Jun 2023 12:01:01 +0100 Subject: [PATCH 13/43] #917 - Synced with dev and added better logging --- src/primaite/__init__.py | 47 +++++++++- src/primaite/agents/agent.py | 23 ++--- src/primaite/agents/rllib.py | 3 +- src/primaite/agents/sb3.py | 10 ++- src/primaite/agents/utils.py | 2 +- src/primaite/cli.py | 19 ++++- src/primaite/common/enums.py | 11 +-- .../training/training_config_main.yaml | 19 ++--- src/primaite/config/training_config.py | 19 ++++- src/primaite/environment/primaite_env.py | 85 ++++++++++--------- src/primaite/primaite_session.py | 12 +-- .../setup/_package_data/primaite_config.yaml | 7 +- 12 files changed, 170 insertions(+), 87 deletions(-) diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 420420f4..24815727 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -2,9 +2,12 @@ import logging import logging.config import sys -from logging import Logger, StreamHandler +from bisect import bisect +from logging import Formatter, LogRecord, StreamHandler +from logging import Logger from logging.handlers import RotatingFileHandler from pathlib import Path +from typing import Dict from typing import Final import pkg_resources @@ -68,6 +71,33 @@ Users PrimAITE Sessions are stored at: ``~/primaite/sessions``. 
# region Setup Logging +class _LevelFormatter(Formatter): + """ + A custom level-specific formatter. + + Credit to: https://stackoverflow.com/a/68154386 + """ + + def __init__(self, formats: Dict[int, str], **kwargs): + super().__init__() + + if "fmt" in kwargs: + raise ValueError( + "Format string must be passed to level-surrogate formatters, " + "not this one" + ) + + self.formats = sorted( + (level, Formatter(fmt, **kwargs)) for level, fmt in formats.items() + ) + + def format(self, record: LogRecord) -> str: + """Overrides ``Formatter.format``.""" + idx = bisect(self.formats, (record.levelno,), hi=len(self.formats) - 1) + level, formatter = self.formats[idx] + return formatter.format(record) + + def _log_dir() -> Path: if sys.platform == "win32": dir_path = _PLATFORM_DIRS.user_data_path / "logs" @@ -76,6 +106,16 @@ def _log_dir() -> Path: return dir_path +_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( + { + logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"], + logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"], + logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"], + logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"], + logging.CRITICAL: _PRIMAITE_CONFIG["logger_format"]["CRITICAL"] + } +) + LOG_DIR: Final[Path] = _log_dir() """The path to the app log directory as an instance of `Path` or `PosixPath`, depending on the OS.""" @@ -85,6 +125,7 @@ LOG_PATH: Final[Path] = LOG_DIR / "primaite.log" """The primaite.log file path as an instance of `Path` or `PosixPath`, depending on the OS.""" _STREAM_HANDLER: Final[StreamHandler] = StreamHandler() + _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( filename=LOG_PATH, maxBytes=10485760, # 10MB @@ -95,8 +136,8 @@ _STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) _FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) _LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"] -_STREAM_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) 
-_FILE_HANDLER.setFormatter(logging.Formatter(_LOG_FORMAT_STR)) +_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER) +_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER) _LOGGER = logging.getLogger(__name__) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 812072ba..5f4dac8f 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -3,11 +3,11 @@ import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Optional, Final, Dict, Union, List +from typing import Optional, Final, Dict, Union from uuid import uuid4 +import primaite from primaite import getLogger, SESSIONS_DIR -from primaite.common.enums import OutputVerboseLevel from primaite.config import lay_down_config from primaite.config import training_config from primaite.config.training_config import TrainingConfig @@ -141,14 +141,13 @@ class AgentSessionABC(ABC): @abstractmethod def _setup(self): - if self.output_verbose_level >= OutputVerboseLevel.INFO: - _LOGGER.info( - "Welcome to the Primary-level AI Training Environment " - "(PrimAITE)" - ) - _LOGGER.debug( - f"The output directory for this agent is: {self.session_path}" - ) + _LOGGER.info( + "Welcome to the Primary-level AI Training Environment " + f"(PrimAITE) (version: {primaite.__version__})" + ) + _LOGGER.info( + f"The output directory for this session is: {self.session_path}" + ) self._write_session_metadata_file() self._can_learn = True self._can_evaluate = False @@ -165,6 +164,7 @@ class AgentSessionABC(ABC): **kwargs ): if self._can_learn: + _LOGGER.info("Finished learning") _LOGGER.debug("Writing transactions") self._update_session_metadata_file() self._can_evaluate = True @@ -176,7 +176,7 @@ class AgentSessionABC(ABC): episodes: Optional[int] = None, **kwargs ): - pass + _LOGGER.info("Finished evaluation") @abstractmethod def _get_latest_checkpoint(self): @@ -260,6 +260,7 @@ class HardCodedAgentSessionABC(AgentSessionABC): # Introduce a delay 
between steps time.sleep(self._training_config.time_delay / 1000) self._env.close() + super().evaluate() @classmethod def load(cls): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index b4b0ec56..710225d7 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -152,7 +152,8 @@ class RLlibAgent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - + _LOGGER.info(f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps...") for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 073eb2fe..4d2ded6b 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -69,8 +69,9 @@ class SB3Agent(AgentSessionABC): (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes) ): - self._agent.save( - self.checkpoints_path / f"sb3ppo_{episode_count}.zip") + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + self._agent.save(checkpoint_path) + _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") def _get_latest_checkpoint(self): pass @@ -86,6 +87,8 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes + _LOGGER.info(f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps...") for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() @@ -106,6 +109,9 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes + _LOGGER.info(f"Beginning evaluation for {episodes} episodes @" + f" {time_steps} time steps...") + for episode in range(episodes): obs = self._env.reset() diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index acc71590..a4eadc3b 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -6,8 +6,8 @@ from 
primaite.common.enums import ( NodeHardwareAction, NodeSoftwareAction, SoftwareState, + NodePOLType ) -from primaite.common.enums import NodePOLType def transform_action_node_readable(action): diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 319d643c..aa88a391 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -160,13 +160,26 @@ def setup(overwrite_existing: bool = True): @app.command() -def session(tc: str, ldc: str): +def session(tc: Optional[str] = None, ldc: Optional[str] = None): """ Run a PrimAITE session. - :param tc: The training config filepath. - :param ldc: The lay down config file path. + tc: The training config filepath. Optional. If no value is passed then + example default training config is used from: + ~/primaite/config/example_config/training/training_config_main.yaml. + + ldc: The lay down config file path. Optional. If no value is passed then + example default lay down config is used from: + ~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml. 
""" from primaite.main import run + from primaite.config.training_config import main_training_config_path + from primaite.config.lay_down_config import data_manipulation_config_path + + if not tc: + tc = main_training_config_path() + + if not ldc: + ldc = data_manipulation_config_path() run(training_config_path=tc, lay_down_config_path=ldc) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 191cb782..6a93e1b5 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -83,9 +83,12 @@ class Protocol(Enum): class SessionType(Enum): """The type of PrimAITE Session to be run.""" - TRAINING = 1 - EVALUATION = 2 - BOTH = 3 + TRAIN = 1 + "Train an agent" + EVAL = 2 + "Evaluate an agent" + TRAIN_EVAL = 3 + "Train then evaluate an agent" class VerboseLevel(IntEnum): @@ -141,7 +144,6 @@ class HardCodedAgentView(Enum): class ActionType(Enum): """Action type enumeration.""" - NODE = 0 ACL = 1 ANY = 2 @@ -149,7 +151,6 @@ class ActionType(Enum): class ObservationType(Enum): """Observation type enumeration.""" - BOX = 0 MULTIDISCRETE = 1 diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 0f99a501..9cbcb702 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -38,26 +38,23 @@ hard_coded_agent_view: FULL action_type: ANY # Number of episodes to run per session -num_episodes: 100 +num_episodes: 10 # Number of time_steps per episode num_steps: 256 # Sets how often the agent will save a checkpoint (every n time episodes). # Set to 0 if no checkpoints are required. 
Default is 10 -checkpoint_every_n_episodes: 100 +checkpoint_every_n_episodes: 10 # Time delay between steps (for generic agents) -time_delay: 3 +time_delay: 5 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING - -# Determine whether to load an agent from file -load_agent: False - -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] +# Type of session to be run. Options are: +# "TRAIN" (Trains an agent) +# "EVAL" (Evaluates an agent) +# "TRAIN_EVAL" (Trains then evaluates an agent) +session_type: TRAIN # Environment config values # The high value for the observation space diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 020d5b03..72b5523a 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -69,7 +69,7 @@ class TrainingConfig: "The delay between steps (ms). Applies to generic agents only" # file - session_type: SessionType = SessionType.TRAINING + session_type: SessionType = SessionType.TRAIN "The type of PrimAITE session to run" load_agent: str = False @@ -171,6 +171,7 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system" + @classmethod def from_dict( cls, @@ -183,7 +184,7 @@ class TrainingConfig: "action_type": ActionType, "session_type": SessionType, "output_verbose_level": OutputVerboseLevel, - "hard_coded_agent_view": HardCodedAgentView + "hard_coded_agent_view": HardCodedAgentView, } for field, enum_class in field_enum_map.items(): @@ -211,6 +212,20 @@ class TrainingConfig: return data + def __str__(self) -> str: + tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, " + if self.agent_framework is AgentFramework.RLLIB: + tc += f"deep_learning_framework=" \ + f"{self.deep_learning_framework.name}, " + tc += f"agent_identifier={self.agent_identifier.name}, " + if self.agent_identifier is AgentIdentifier.HARDCODED: + 
tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, " + tc += f"action_type={self.action_type.name}, " + tc += f"observation_space={self.observation_space}, " + tc += f"num_episodes={self.num_episodes}, " + tc += f"num_steps={self.num_steps})" + return tc + def load( file_path: Union[str, Path], diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 44f576ce..5319d0f1 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -15,7 +15,8 @@ from gym import Env, spaces from matplotlib import pyplot as plt from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action +from primaite.agents.utils import is_valid_acl_action_extra, \ + is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -36,13 +37,15 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import \ + NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol +from primaite.pol.red_agent_pol import apply_red_agent_iers, \ + apply_red_agent_node_pol from primaite.transactions.transaction import Transaction from primaite.transactions.transactions_to_file import \ write_transaction_to_file @@ -61,12 +64,12 @@ class Primaite(Env): 
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -86,6 +89,7 @@ class Primaite(Env): self.training_config: TrainingConfig = training_config.load( training_config_path ) + _LOGGER.info(f"Using: {str(self.training_config)}") # Number of steps in an episode self.episode_steps = self.training_config.num_steps @@ -207,16 +211,14 @@ class Primaite(Env): plt.savefig(file_path, format="PNG") plt.clf() except Exception: - _LOGGER.error("Could not save network diagram") - _LOGGER.error("Exception occured", exc_info=True) - print("Could not save network diagram") + _LOGGER.error("Could not save network diagram", exc_info=True) # Initiate observation space self.observation_space, self.env_obs = self.init_observations() # Define Action Space - depends on action space type (Node or ACL) if self.training_config.action_type == ActionType.NODE: - _LOGGER.info("Action space type NODE selected") + _LOGGER.debug("Action space type NODE selected") # Terms (for node action space): # [0, num nodes] - node ID (0 = nothing, node ID) # [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa @@ -225,7 +227,7 @@ class Primaite(Env): self.action_dict = self.create_node_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) elif self.training_config.action_type == ActionType.ACL: - _LOGGER.info("Action space type ACL selected") + _LOGGER.debug("Action space type ACL selected") # Terms (for ACL action space): # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) # [0, 1] - Permission (0 = DENY, 1 = ALLOW) @@ -236,11 +238,11 @@ class Primaite(Env): 
self.action_dict = self.create_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) elif self.training_config.action_type == ActionType.ANY: - _LOGGER.info("Action space type ANY selected - Node + ACL") + _LOGGER.debug("Action space type ANY selected - Node + ACL") self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.info( + _LOGGER.error( f"Invalid action type selected: {self.training_config.action_type}" ) # Set up a csv to store the results of the training @@ -301,17 +303,14 @@ class Primaite(Env): done: Indicates episode is complete if True step_info: Additional information relating to this step """ - # Introduce a delay between steps - time.sleep(self.training_config.time_delay / 1000) if self.step_count == 0: - print(f"Episode: {str(self.episode_count)}") + _LOGGER.info(f"Episode: {str(self.episode_count)}") # TEMP done = False self.step_count += 1 self.total_step_count += 1 - # print("Episode step: " + str(self.step_count)) # Need to clear traffic on all links first for link_key, link_value in self.links.items(): @@ -322,7 +321,8 @@ class Primaite(Env): # Create a Transaction (metric) object for this step transaction = Transaction( - datetime.now(), self.agent_identifier, self.episode_count, self.step_count + datetime.now(), self.agent_identifier, self.episode_count, + self.step_count ) # Load the initial observation space into the transaction transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) @@ -352,7 +352,8 @@ class Primaite(Env): self.nodes_post_pol = copy.deepcopy(self.nodes) self.links_post_pol = copy.deepcopy(self.links) # Reference - apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL + apply_node_pol(self.nodes_reference, self.node_pol, + self.step_count) # Node PoL apply_iers( self.network_reference, self.nodes_reference, @@ -389,7 +390,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - # 
print(f" Step {self.step_count} Reward: {str(reward)}") + _LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -397,7 +398,7 @@ class Primaite(Env): # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - print(f" Average Reward: {str(self.average_reward)}") + _LOGGER.info(f" Average Reward: {str(self.average_reward)}") # Load the reward into the transaction transaction.set_reward(reward) @@ -428,6 +429,7 @@ class Primaite(Env): self.timestamp_str ) self.csv_file.close() + def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() @@ -435,9 +437,9 @@ class Primaite(Env): def output_link_status(self): """Output the link status of all links to the console.""" for link_key, link_value in self.links.items(): - print("Link ID: " + link_value.get_id()) + _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - print( + _LOGGER.debug( " Protocol: " + protocol.get_name().name + ", Load: " @@ -457,11 +459,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -529,7 +531,8 @@ class Primaite(Env): elif property_action == 1: # Patch (valid action if it's good or compromised) node.set_service_state( - self.services_list[service_index], SoftwareState.PATCHING + self.services_list[service_index], + SoftwareState.PATCHING ) else: # Node is not of Service Type @@ -589,7 +592,8 @@ 
class Primaite(Env): acl_rule_source = "ANY" else: node = list(self.nodes.values())[action_source_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): + if isinstance(node, ServiceNode) or isinstance(node, + ActiveNode): acl_rule_source = node.ip_address else: return @@ -598,7 +602,8 @@ class Primaite(Env): acl_rule_destination = "ANY" else: node = list(self.nodes.values())[action_destination_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): + if isinstance(node, ServiceNode) or isinstance(node, + ActiveNode): acl_rule_destination = node.ip_address else: return @@ -683,7 +688,8 @@ class Primaite(Env): :return: The observation space, initial observation (zeroed out array with the correct shape) :rtype: Tuple[spaces.Space, np.ndarray] """ - self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + self.obs_handler = ObservationsHandler.from_config(self, + self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation @@ -727,8 +733,7 @@ class Primaite(Env): _LOGGER.error(f"Invalid item_type: {item_type}") pass - _LOGGER.info("Environment configuration loaded") - print("Environment configuration loaded") + _LOGGER.debug("Environment configuration loaded") def create_node(self, item): """ @@ -791,7 +796,8 @@ class Primaite(Env): service_protocol = service["name"] service_port = service["port"] service_state = SoftwareState[service["state"]] - node.add_service(Service(service_protocol, service_port, service_state)) + node.add_service( + Service(service_protocol, service_port, service_state)) else: # Bad formatting pass @@ -844,7 +850,8 @@ class Primaite(Env): dest_node_ref: Node = self.nodes_reference[link_destination] # Add link to network (reference) - self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name) + self.network_reference.add_edge(source_node_ref, dest_node_ref, + id=link_name) # Add link to link dictionary (reference) 
self.links_reference[link_name] = Link( @@ -1119,7 +1126,8 @@ class Primaite(Env): # All nodes have these parameters node_id = item["node_id"] node_class = item["node_class"] - node_hardware_state: HardwareState = HardwareState[item["hardware_state"]] + node_hardware_state: HardwareState = HardwareState[ + item["hardware_state"]] node: NodeUnion = self.nodes[node_id] node_ref = self.nodes_reference[node_id] @@ -1185,7 +1193,8 @@ class Primaite(Env): # Use MAX to ensure we get them all for node_action in range(4): for service_state in range(self.num_services): - action = [node, node_property, node_action, service_state] + action = [node, node_property, node_action, + service_state] # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): actions[action_key] = action diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 70a18a4b..cd959be0 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -47,8 +47,7 @@ class PrimaiteSession: def __init__( self, training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - auto: bool = True + lay_down_config_path: Union[str, Path] ): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) @@ -64,13 +63,8 @@ class PrimaiteSession: self._lay_down_config_path ) - self._auto: bool = auto self._agent_session: AgentSessionABC = None # noqa - if self._auto: - self.setup() - self.learn() - def setup(self): if self._training_config.agent_framework == AgentFramework.CUSTOM: if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: @@ -157,7 +151,7 @@ class PrimaiteSession: episodes: Optional[int] = None, **kwargs ): - if not self._training_config.session_type == SessionType.EVALUATION: + if not self._training_config.session_type == SessionType.EVAL: self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( @@ -166,5 +160,5 @@ class PrimaiteSession: 
episodes: Optional[int] = None, **kwargs ): - if not self._training_config.session_type == SessionType.TRAINING: + if not self._training_config.session_type == SessionType.TRAIN: self._agent_session.evaluate(time_steps, episodes, **kwargs) diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index 690544fb..5d469ffe 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -2,4 +2,9 @@ # Logging log_level: INFO -logger_format: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' +logger_format: + DEBUG: '%(asctime)s: %(message)s' + INFO: '%(asctime)s: %(message)s' + WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' From 1d3778f400ff1164389276dbcc34b34e1276abaa Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Jun 2023 16:34:00 +0100 Subject: [PATCH 14/43] #917 - Overhauled transaction and mean reward writing. 
- Separated out learning outputs from evaluation outputs --- src/primaite/__init__.py | 21 +-- src/primaite/agents/agent.py | 7 +- src/primaite/agents/rllib.py | 2 - src/primaite/agents/sb3.py | 15 +- .../training/training_config_main.yaml | 2 +- src/primaite/environment/primaite_env.py | 113 +++++++-------- src/primaite/environment/reward.py | 3 - src/primaite/main.py | 1 + .../setup/_package_data/primaite_config.yaml | 15 +- src/primaite/transactions/transaction.py | 136 +++++++++++++----- .../transactions/transactions_to_file.py | 119 --------------- src/primaite/utils/session_output_writer.py | 73 ++++++++++ tests/conftest.py | 1 - 13 files changed, 258 insertions(+), 250 deletions(-) delete mode 100644 src/primaite/transactions/transactions_to_file.py create mode 100644 src/primaite/utils/session_output_writer.py diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 24815727..64857c80 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -21,6 +21,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") def _get_primaite_config(): config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" if not config_path.exists(): + config_path = Path( pkg_resources.resource_filename( "primaite", "setup/_package_data/primaite_config.yaml" @@ -36,7 +37,7 @@ def _get_primaite_config(): "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - primaite_config["log_level"] = log_level_map[primaite_config["log_level"]] + primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] return primaite_config @@ -108,11 +109,11 @@ def _log_dir() -> Path: _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { - logging.DEBUG: _PRIMAITE_CONFIG["logger_format"]["DEBUG"], - logging.INFO: _PRIMAITE_CONFIG["logger_format"]["INFO"], - logging.WARNING: _PRIMAITE_CONFIG["logger_format"]["WARNING"], - logging.ERROR: _PRIMAITE_CONFIG["logger_format"]["ERROR"], - logging.CRITICAL: 
_PRIMAITE_CONFIG["logger_format"]["CRITICAL"] + logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], + logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], + logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], + logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], + logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"] } ) @@ -132,10 +133,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler( backupCount=9, # Max 100MB of logs encoding="utf8", ) -_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) -_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["log_level"]) +_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) +_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) -_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logger_format"] +_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"] _STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER) _FILE_HANDLER.setFormatter(_LEVEL_FORMATTER) @@ -145,7 +146,7 @@ _LOGGER.addHandler(_STREAM_HANDLER) _LOGGER.addHandler(_FILE_HANDLER) -def getLogger(name: str) -> Logger: +def getLogger(name: str) -> Logger: # noqa """ Get a PrimAITE logger. 
diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 5f4dac8f..05133b7e 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -64,7 +64,11 @@ class AgentSessionABC(ABC): "The session timestamp" self.session_path = _get_session_path(self.session_timestamp) "The Session path" - self.checkpoints_path = self.session_path / "checkpoints" + self.learning_path = self.session_path / "learning" + "The learning outputs path" + self.evaluation_path = self.session_path / "evaluation" + "The evaluation outputs path" + self.checkpoints_path = self.learning_path / "checkpoints" self.checkpoints_path.mkdir(parents=True, exist_ok=True) "The Session checkpoints path" @@ -205,7 +209,6 @@ class HardCodedAgentSessionABC(AgentSessionABC): self._env: Primaite = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, - transaction_list=[], session_path=self.session_path, timestamp_str=self.timestamp_str ) diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 710225d7..8a6428bb 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -21,7 +21,6 @@ def _env_creator(env_config): return Primaite( training_config_path=env_config["training_config_path"], lay_down_config_path=env_config["lay_down_config_path"], - transaction_list=env_config["transaction_list"], session_path=env_config["session_path"], timestamp_str=env_config["timestamp_str"] ) @@ -106,7 +105,6 @@ class RLlibAgent(AgentSessionABC): env_config=dict( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, - transaction_list=[], session_path=self.session_path, timestamp_str=self.timestamp_str ) diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 4d2ded6b..c183c544 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -34,7 +34,7 @@ class SB3Agent(AgentSessionABC): _LOGGER.error(msg) raise 
ValueError(msg) - self._tensorboard_log_path = self.session_path / "tensorboard_logs" + self._tensorboard_log_path = self.learning_path / "tensorboard_logs" self._tensorboard_log_path.mkdir(parents=True, exist_ok=True) self._setup() _LOGGER.debug( @@ -49,7 +49,6 @@ class SB3Agent(AgentSessionABC): self._env = Primaite( training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, - transaction_list=[], session_path=self.session_path, timestamp_str=self.timestamp_str ) @@ -108,10 +107,13 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - - _LOGGER.info(f"Beginning evaluation for {episodes} episodes @" - f" {time_steps} time steps...") - + self._env.set_as_eval() + if deterministic: + deterministic_str = "deterministic" + else: + deterministic_str = "non-deterministic" + _LOGGER.info(f"Beginning {deterministic_str} evaluation for " + f"{episodes} episodes @ {time_steps} time steps...") for episode in range(episodes): obs = self._env.reset() @@ -123,6 +125,7 @@ class SB3Agent(AgentSessionABC): if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) + _LOGGER.info(f"Finished evaluation") @classmethod def load(self): diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 9cbcb702..0e0212f4 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -38,7 +38,7 @@ hard_coded_agent_view: FULL action_type: ANY # Number of episodes to run per session -num_episodes: 10 +num_episodes: 1000 # Number of time_steps per episode num_steps: 256 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5319d0f1..5b344a99 100644 --- a/src/primaite/environment/primaite_env.py +++ 
b/src/primaite/environment/primaite_env.py @@ -1,10 +1,6 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy -import csv -import logging -import time -from datetime import datetime from pathlib import Path from typing import Dict, Tuple, Union, Final @@ -14,6 +10,7 @@ import yaml from gym import Env, spaces from matplotlib import pyplot as plt +from primaite import getLogger from primaite.acl.access_control_list import AccessControlList from primaite.agents.utils import is_valid_acl_action_extra, \ is_valid_node_action @@ -27,7 +24,7 @@ from primaite.common.enums import ( NodeType, ObservationType, Priority, - SoftwareState, + SoftwareState, SessionType, ) from primaite.common.service import Service from primaite.config import training_config @@ -47,11 +44,9 @@ from primaite.pol.ier import IER from primaite.pol.red_agent_pol import apply_red_agent_iers, \ apply_red_agent_node_pol from primaite.transactions.transaction import Transaction -from primaite.transactions.transactions_to_file import \ - write_transaction_to_file +from primaite.utils.session_output_writer import SessionOutputWriter -_LOGGER = logging.getLogger(__name__) -_LOGGER.setLevel(logging.INFO) +_LOGGER = getLogger(__name__) class Primaite(Env): @@ -67,7 +62,6 @@ class Primaite(Env): self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path], - transaction_list, session_path: Path, timestamp_str: str, ): @@ -76,7 +70,6 @@ class Primaite(Env): :param training_config_path: The training config filepath. :param lay_down_config_path: The lay down config filepath. - :param transaction_list: The list of transactions to populate. :param session_path: The directory path the session is writing to. :param timestamp_str: The session timestamp in the format: _. 
@@ -96,9 +89,6 @@ class Primaite(Env): super(Primaite, self).__init__() - # Transaction list - self.transaction_list = transaction_list - # The agent in use self.agent_identifier = self.training_config.agent_identifier @@ -245,20 +235,31 @@ class Primaite(Env): _LOGGER.error( f"Invalid action type selected: {self.training_config.action_type}" ) - # Set up a csv to store the results of the training - try: - header = ["Episode", "Average Reward"] - file_name = f"average_reward_per_episode_{timestamp_str}.csv" - file_path = session_path / file_name - self.csv_file = open(file_path, "w", encoding="UTF8", newline="") - self.csv_writer = csv.writer(self.csv_file) - self.csv_writer.writerow(header) - except Exception: - _LOGGER.error( - "Could not create csv file to hold average reward per episode" - ) - _LOGGER.error("Exception occured", exc_info=True) + self.episode_av_reward_writer = SessionOutputWriter( + self, + transaction_writer=False, + learning_session=True + ) + self.transaction_writer = SessionOutputWriter( + self, + transaction_writer=True, + learning_session=True + ) + + def set_as_eval(self): + """Set the writers to write to eval directories.""" + self.episode_av_reward_writer = SessionOutputWriter( + self, + transaction_writer=False, + learning_session=False + ) + self.transaction_writer = SessionOutputWriter( + self, + transaction_writer=True, + learning_session=False + ) + self.episode_count = 0 def reset(self): """ @@ -267,12 +268,14 @@ class Primaite(Env): Returns: Environment observation space (reset) """ - csv_data = self.episode_count, self.average_reward - self.csv_writer.writerow(csv_data) + if self.episode_count > 0: + csv_data = self.episode_count, self.average_reward + self.episode_av_reward_writer.write(csv_data) self.episode_count += 1 - # Don't need to reset links, as they are cleared and recalculated every step + # Don't need to reset links, as they are cleared and recalculated every + # step # Clear the ACL self.init_acl() @@ -303,12 +306,8 
@@ class Primaite(Env): done: Indicates episode is complete if True step_info: Additional information relating to this step """ - if self.step_count == 0: - _LOGGER.info(f"Episode: {str(self.episode_count)}") - # TEMP done = False - self.step_count += 1 self.total_step_count += 1 @@ -321,13 +320,16 @@ class Primaite(Env): # Create a Transaction (metric) object for this step transaction = Transaction( - datetime.now(), self.agent_identifier, self.episode_count, + self.agent_identifier, + self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) + transaction.obs_space_pre = copy.deepcopy(self.env_obs) # Load the action space into the transaction - transaction.set_action_space(copy.deepcopy(action)) + transaction.action_space = copy.deepcopy(action) + + initial_nodes = copy.deepcopy(self.nodes) # 1. Implement Blue Action self.interpret_action_and_apply(action) @@ -381,7 +383,7 @@ class Primaite(Env): # 5. 
Calculate reward signal (for RL) reward = calculate_reward_function( - self.nodes_post_pol, + initial_nodes, self.nodes_post_red, self.nodes_reference, self.green_iers, @@ -390,17 +392,22 @@ class Primaite(Env): self.step_count, self.training_config, ) - _LOGGER.debug(f" Step {self.step_count} Reward: {str(reward)}") + _LOGGER.debug( + f"Episode: {self.episode_count}, " + f"Step {self.step_count}, " + f"Reward: {reward}" + ) self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count - if self.training_config.session_type == "EVALUATION": + if self.training_config.session_type is SessionType.EVAL: # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - _LOGGER.info(f" Average Reward: {str(self.average_reward)}") + _LOGGER.info(f"Episode: {self.episode_count}, " + f"Average Reward: {self.average_reward}") # Load the reward into the transaction - transaction.set_reward(reward) + transaction.reward = reward # 6. Output Verbose # self.output_link_status() @@ -408,28 +415,14 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() # Load the new observation space into the transaction - transaction.set_obs_space_post(copy.deepcopy(self.env_obs)) + transaction.obs_space_post = copy.deepcopy(self.env_obs) - # 8. 
Add the transaction to the list of transactions - self.transaction_list.append(copy.deepcopy(transaction)) + # Write transaction to file + self.transaction_writer.write(transaction) # Return return self.env_obs, reward, done, self.step_info - def close(self): - self.__close__() - - def __close__(self): - """ - Override close function - """ - write_transaction_to_file( - self.transaction_list, - self.session_path, - self.timestamp_str - ) - self.csv_file.close() - def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() @@ -467,7 +460,7 @@ class Primaite(Env): ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: - logging.error("Invalid action type found") + _LOGGER.error("Invalid action type found") def apply_actions_to_nodes(self, _action): """ diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 1a1a0770..00e45fa3 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -85,9 +85,6 @@ def calculate_reward_function( ) if live_blocked and not reference_blocked: - _LOGGER.debug( - f"Applying reward of {ier_reward} because IER {ier_key} is blocked" - ) reward_value += ier_reward elif live_blocked and reference_blocked: _LOGGER.debug( diff --git a/src/primaite/main.py b/src/primaite/main.py index 5aba68ef..3c0f93b3 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -20,6 +20,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, session.setup() session.learn() + session.evaluate() if __name__ == "__main__": diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index 5d469ffe..1dd8775b 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -1,10 +1,11 @@ # The main PrimAITE application config file # Logging -log_level: INFO 
-logger_format: - DEBUG: '%(asctime)s: %(message)s' - INFO: '%(asctime)s: %(message)s' - WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' - ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' - CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' +logging: + log_level: INFO + logger_format: + DEBUG: '%(asctime)s: %(message)s' + INFO: '%(asctime)s: %(message)s' + WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index a4ce48e3..6e5ba5f0 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -1,57 +1,115 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """The Transaction class.""" +from datetime import datetime +from typing import List, Tuple class Transaction(object): """Transaction class.""" - def __init__(self, _timestamp, _agent_identifier, _episode_number, _step_number): + def __init__( + self, + agent_identifier, + episode_number, + step_number + ): """ - Init. + Transaction constructor. 
- Args: - _timestamp: The time this object was created - _agent_identifier: An identifier for the agent in use - _episode_number: The episode number - _step_number: The step number + :param agent_identifier: An identifier for the agent in use + :param episode_number: The episode number + :param step_number: The step number """ - self.timestamp = _timestamp - self.agent_identifier = _agent_identifier - self.episode_number = _episode_number - self.step_number = _step_number + self.timestamp = datetime.now() + "The datetime of the transaction" + self.agent_identifier = agent_identifier + self.episode_number = episode_number + "The episode number" + self.step_number = step_number + "The step number" + self.obs_space_pre = None + "The observation space before any actions are taken" + self.obs_space_post = None + "The observation space after any actions are taken" + self.reward = None + "The reward value" + self.action_space = None + "The action space invoked by the agent" - def set_obs_space_pre(self, _obs_space_pre): - """ - Sets the observation space (pre). + def as_csv_data(self) -> Tuple[List, List]: + if isinstance(self.action_space, int): + action_length = self.action_space + else: + action_length = self.action_space.size + obs_shape = self.obs_space_post.shape + obs_assets = self.obs_space_post.shape[0] + if len(obs_shape) == 1: + # A bit of a workaround but I think the way transactions are + # written will change soon + obs_features = 1 + else: + obs_features = self.obs_space_post.shape[1] - Args: - _obs_space_pre: The observation space before any actions are taken - """ - self.obs_space_pre = _obs_space_pre + # Create the action space headers array + action_header = [] + for x in range(action_length): + action_header.append("AS_" + str(x)) - def set_obs_space_post(self, _obs_space_post): - """ - Sets the observation space (post). 
+ # Create the observation space headers array + obs_header_initial = [] + obs_header_new = [] + for x in range(obs_assets): + for y in range(obs_features): + obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) + obs_header_new.append("OSN_" + str(x) + "_" + str(y)) - Args: - _obs_space_post: The observation space after any actions are taken - """ - self.obs_space_post = _obs_space_post + # Open up a csv file + header = ["Timestamp", "Episode", "Step", "Reward"] + header = header + action_header + obs_header_initial + obs_header_new - def set_reward(self, _reward): - """ - Sets the reward. + row = [ + str(self.timestamp), + str(self.episode_number), + str(self.step_number), + str(self.reward), + ] + row = ( + row + + _turn_action_space_to_array(self.action_space) + + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, + obs_features) + + _turn_obs_space_to_array(self.obs_space_post, obs_assets, + obs_features) + ) + return header, row - Args: - _reward: The reward value - """ - self.reward = _reward - def set_action_space(self, _action_space): - """ - Sets the action space. +def _turn_action_space_to_array(action_space) -> List[str]: + """ + Turns action space into a string array so it can be saved to csv. - Args: - _action_space: The action space invoked by the agent - """ - self.action_space = _action_space + :param action_space: The action space + :return: The action space as an array of strings + """ + if isinstance(action_space, list): + return [str(i) for i in action_space] + else: + return [str(action_space)] + + +def _turn_obs_space_to_array(obs_space, obs_assets, obs_features) -> List[str]: + """ + Turns observation space into a string array so it can be saved to csv. + + :param obs_space: The observation space + :param obs_assets: The number of assets (i.e. 
nodes or links) in the + observation space + :param obs_features: The number of features associated with the asset + :return: The observation space as an array of strings + """ + return_array = [] + for x in range(obs_assets): + for y in range(obs_features): + return_array.append(str(obs_space[x][y])) + + return return_array diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py deleted file mode 100644 index ed7a8f1c..00000000 --- a/src/primaite/transactions/transactions_to_file.py +++ /dev/null @@ -1,119 +0,0 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -"""Writes the Transaction log list out to file for evaluation to utilse.""" - -import csv -from pathlib import Path - -import numpy as np - -from primaite import getLogger - -_LOGGER = getLogger(__name__) - - -def turn_action_space_to_array(_action_space): - """ - Turns action space into a string array so it can be saved to csv. - - Args: - _action_space: The action space. - """ - if isinstance(_action_space, list): - return [str(i) for i in _action_space] - else: - return [str(_action_space)] - - -def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): - """ - Turns observation space into a string array so it can be saved to csv. - - Args: - _obs_space: The observation space - _obs_assets: The number of assets (i.e. nodes or links) in the observation space - _obs_features: The number of features associated with the asset - """ - return_array = [] - for x in range(_obs_assets): - for y in range(_obs_features): - return_array.append(str(_obs_space[x][y])) - - return return_array - - -def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str): - """ - Writes transaction logs to file to support training evaluation. - - :param transaction_list: The list of transactions from all steps and all - episodes. - :param session_path: The directory path the session is writing to. 
- :param timestamp_str: The session timestamp in the format: - _. - """ - # Get the first transaction and use it to determine the makeup of the - # observation space and action space - # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action - # space as "AS_1" - # This will be tied into the PrimAITE Use Case so that they make sense - - template_transation = transaction_list[0] - if isinstance(template_transation.action_space, int): - action_length = template_transation.action_space - else: - action_length = template_transation.action_space.size - obs_shape = template_transation.obs_space_post.shape - obs_assets = template_transation.obs_space_post.shape[0] - if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - obs_features = 1 - else: - obs_features = template_transation.obs_space_post.shape[1] - - # Create the action space headers array - action_header = [] - for x in range(action_length): - action_header.append("AS_" + str(x)) - - # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) - - # Open up a csv file - header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + obs_header_new - - try: - filename = session_path / f"all_transactions_{timestamp_str}.csv" - _LOGGER.debug(f"Saving transaction logs: {filename}") - csv_file = open(filename, "w", encoding="UTF8", newline="") - csv_writer = csv.writer(csv_file) - csv_writer.writerow(header) - - for transaction in transaction_list: - csv_data = [ - str(transaction.timestamp), - str(transaction.episode_number), - str(transaction.step_number), - str(transaction.reward), - ] - csv_data = ( - csv_data - + turn_action_space_to_array(transaction.action_space) - + turn_obs_space_to_array( 
- transaction.obs_space_pre, obs_assets, obs_features - ) - + turn_obs_space_to_array( - transaction.obs_space_post, obs_assets, obs_features - ) - ) - csv_writer.writerow(csv_data) - - csv_file.close() - _LOGGER.debug("Finished writing transactions") - except Exception: - _LOGGER.error("Could not save the transaction file", exc_info=True) diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py new file mode 100644 index 00000000..308e1fb3 --- /dev/null +++ b/src/primaite/utils/session_output_writer.py @@ -0,0 +1,73 @@ +import csv +from logging import Logger +from typing import List, Final, IO, Union, Tuple +from typing import TYPE_CHECKING + +from primaite import getLogger +from primaite.transactions.transaction import Transaction + +if TYPE_CHECKING: + from primaite.environment.primaite_env import Primaite + +_LOGGER: Logger = getLogger(__name__) + + +class SessionOutputWriter: + _AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [ + "Episode", "Average Reward" + ] + + def __init__( + self, + env: "Primaite", + transaction_writer: bool = False, + learning_session: bool = True + ): + self._env = env + self.transaction_writer = transaction_writer + self.learning_session = learning_session + + if self.transaction_writer: + fn = f"all_transactions_{self._env.timestamp_str}.csv" + else: + fn = f"average_reward_per_episode_{self._env.timestamp_str}.csv" + + if self.learning_session: + self._csv_file_path = self._env.session_path / "learning" / fn + else: + self._csv_file_path = self._env.session_path / "evaluation" / fn + + self._csv_file_path.parent.mkdir(exist_ok=True, parents=True) + + self._csv_file = None + self._csv_writer = None + + self._first_write: bool = True + + def _init_csv_writer(self): + self._csv_file = open( + self._csv_file_path, "w", encoding="UTF8", newline="" + ) + + self._csv_writer = csv.writer(self._csv_file) + + def __del__(self): + if self._csv_file: + self._csv_file.close() + 
_LOGGER.info(f"Finished writing file: {self._csv_file_path}") + + def write( + self, + data: Union[Tuple, Transaction] + ): + if isinstance(data, Transaction): + header, data = data.as_csv_data() + else: + header = self._AV_REWARD_PER_EPISODE_HEADER + + if self._first_write: + self._init_csv_writer() + self._csv_writer.writerow(header) + self._first_write = False + + self._csv_writer.writerow(data) diff --git a/tests/conftest.py b/tests/conftest.py index 1bad5db0..945d23f0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,7 +37,6 @@ def _get_primaite_env_from_config( env = Primaite( training_config_path=training_config_path, lay_down_config_path=lay_down_config_path, - transaction_list=[], session_path=session_path, timestamp_str=timestamp_str, ) From 7f912df383e649908071b0a9e7fdccb3bf6291ed Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Wed, 28 Jun 2023 19:54:00 +0100 Subject: [PATCH 15/43] #917 - Began the process of reloading existing agents into the session --- src/primaite/agents/agent.py | 88 ++++++++++++++++--- src/primaite/agents/sb3.py | 4 + .../training/training_config_main.yaml | 2 +- src/primaite/environment/primaite_env.py | 6 +- 4 files changed, 82 insertions(+), 18 deletions(-) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 05133b7e..f545a3cb 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -1,3 +1,4 @@ +from __future__ import annotations import json import time from abc import ABC, abstractmethod @@ -6,6 +7,8 @@ from pathlib import Path from typing import Optional, Final, Dict, Union from uuid import uuid4 +import yaml + import primaite from primaite import getLogger, SESSIONS_DIR from primaite.config import lay_down_config @@ -58,23 +61,34 @@ class AgentSessionABC(ABC): self._agent = None self._can_learn: bool = False self._can_evaluate: bool = False + self.is_eval = False self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() "The session 
timestamp" self.session_path = _get_session_path(self.session_timestamp) "The Session path" - self.learning_path = self.session_path / "learning" - "The learning outputs path" - self.evaluation_path = self.session_path / "evaluation" - "The evaluation outputs path" - self.checkpoints_path = self.learning_path / "checkpoints" self.checkpoints_path.mkdir(parents=True, exist_ok=True) - "The Session checkpoints path" - self.timestamp_str = self.session_timestamp.strftime( - "%Y-%m-%d_%H-%M-%S") - "The session timestamp as a string" + @property + def timestamp_str(self) -> str: + """The session timestamp as a string.""" + return self.session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + + @property + def learning_path(self) -> Path: + """The learning outputs path.""" + return self.session_path / "learning" + + @property + def evaluation_path(self) -> Path: + """The evaluation outputs path.""" + return self.session_path / "evaluation" + + @property + def checkpoints_path(self) -> Path: + """The Session checkpoints path.""" + return self.learning_path / "checkpoints" @property def uuid(self): @@ -104,8 +118,14 @@ class AgentSessionABC(ABC): "uuid": self.uuid, "start_datetime": self.session_timestamp.isoformat(), "end_datetime": None, - "total_episodes": None, - "total_time_steps": None, + "learning": { + "total_episodes": None, + "total_time_steps": None + }, + "evaluation": { + "total_episodes": None, + "total_time_steps": None + }, "env": { "training_config": self._training_config.to_dict( json_serializable=True @@ -134,8 +154,13 @@ class AgentSessionABC(ABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._env.episode_count - metadata_dict["total_time_steps"] = self._env.total_step_count + + if not self.is_eval: + metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa + else: + 
metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -172,6 +197,7 @@ class AgentSessionABC(ABC): _LOGGER.debug("Writing transactions") self._update_session_metadata_file() self._can_evaluate = True + self.is_eval = False @abstractmethod def evaluate( @@ -180,6 +206,7 @@ class AgentSessionABC(ABC): episodes: Optional[int] = None, **kwargs ): + self.is_eval = True _LOGGER.info("Finished evaluation") @abstractmethod @@ -188,7 +215,40 @@ class AgentSessionABC(ABC): @classmethod @abstractmethod - def load(cls): + def load(cls, path: Union[str, Path]) -> AgentSessionABC: + if not isinstance(path, Path): + path = Path(path) + + if path.exists(): + # Unpack the session_metadata.json file + md_file = path / "session_metadata.json" + with open(md_file, "r") as file: + md_dict = json.load(file) + + # Create a temp directory and dump the training and lay down + # configs into it + temp_dir = path / ".temp" + temp_dir.mkdir(exist_ok=True) + + temp_tc = temp_dir / "tc.yaml" + with open(temp_tc, "w") as file: + yaml.dump(md_dict["env"]["training_config"], file) + + temp_ldc = temp_dir / "ldc.yaml" + with open(temp_ldc, "w") as file: + yaml.dump(md_dict["env"]["lay_down_config"], file) + + agent = cls(temp_tc, temp_ldc) + + agent.session_path = path + + return agent + + else: + # Session path does not exist + msg = f"Failed to load PrimAITE Session, path does not exist: {path}" + _LOGGER.error(msg) + raise FileNotFoundError(msg) pass @abstractmethod diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index c183c544..328e6286 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -44,6 +44,8 @@ class SB3Agent(AgentSessionABC): f"{self._training_config.agent_identifier}" ) + self.is_eval = False + def _setup(self): 
super()._setup() self._env = Primaite( @@ -86,6 +88,7 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes + self.is_eval = False _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): @@ -108,6 +111,7 @@ class SB3Agent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes self._env.set_as_eval() + self.is_eval = True if deterministic: deterministic_str = "deterministic" else: diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 0e0212f4..3cccbcae 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -35,7 +35,7 @@ hard_coded_agent_view: FULL # "NODE" # "ACL" # "ANY" node and acl actions -action_type: ANY +action_type: NODE # Number of episodes to run per session num_episodes: 1000 diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5b344a99..e43dc8a5 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -260,6 +260,8 @@ class Primaite(Env): learning_session=False ) self.episode_count = 0 + self.step_count = 0 + self.total_step_count = 0 def reset(self): """ @@ -329,8 +331,6 @@ class Primaite(Env): # Load the action space into the transaction transaction.action_space = copy.deepcopy(action) - initial_nodes = copy.deepcopy(self.nodes) - # 1. Implement Blue Action self.interpret_action_and_apply(action) # Take snapshots of nodes and links @@ -383,7 +383,7 @@ class Primaite(Env): # 5. 
Calculate reward signal (for RL) reward = calculate_reward_function( - initial_nodes, + self.nodes_post_pol, self.nodes_post_red, self.nodes_reference, self.green_iers, From f61d50a96f88c6cd18ab47c9bab1fb52982d4b35 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Thu, 29 Jun 2023 15:03:11 +0100 Subject: [PATCH 16/43] #1522: fixing create random red agent function --- .gitignore | 2 + src/primaite/environment/primaite_env.py | 149 ++++++++++++----------- 2 files changed, 78 insertions(+), 73 deletions(-) diff --git a/.gitignore b/.gitignore index eed6c903..5adbdc57 100644 --- a/.gitignore +++ b/.gitignore @@ -137,3 +137,5 @@ dmypy.json # Cython debug symbols cython_debug/ + +.idea/ \ No newline at end of file diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index c3d408d2..9ac3d8e6 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -9,6 +9,7 @@ from typing import Dict, Tuple, Union import networkx as nx import numpy as np +import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -58,12 +59,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. 
@@ -275,8 +276,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - if self.training_config.red_agent_identifier == "RANDOM": - self.create_random_red_agent() + # if self.training_config.red_agent_identifier == "RANDOM": + # self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -380,6 +381,7 @@ class Primaite(Env): self.nodes_post_pol, self.nodes_post_red, self.nodes_reference, + self.green_iers, self.green_iers_reference, self.red_iers, self.step_count, @@ -445,11 +447,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1247,14 +1249,17 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised + 1) + num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + # choose a random compromise node to be source of attacks + source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1264,57 +1269,50 @@ class Primaite(Env): for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised - _id = str(1000 + n) # doesn't really 
matter, make sure it doesn't duplicate + _id = str(uuid.uuid4()) _start_step = np.random.randint( 2, max_step_compromised + 1 ) # step compromised - _end_step = _start_step # Become compromised on 1 step - _target_node_id = node.node_id - _pol_initiator = "DIRECT" - _pol_type = NodePOLType["SERVICE"] # All computers are service nodes pol_service_name = np.random.choice( - list(node.get_services().keys()) - ) # Random service may wish to change this, currently always TCP) - pol_protocol = pol_protocol - _pol_state = SoftwareState.COMPROMISED - is_entry_node = True # Assumes all computers in network are entry nodes - _pol_source_node_id = _pol_source_node_id - _pol_source_node_service = _pol_source_node_service - _pol_source_node_service_state = _pol_source_node_service_state + list(node.services.keys()) + ) + + source_node_service = np.random.choice( + list(source_node.services.values()) + ) + red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=_id, + _start_step=_start_step, + _end_step=_start_step, # only run for 1 step + _target_node_id=node.node_id, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=pol_service_name, + _pol_state=SoftwareState.COMPROMISED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[_id] = red_pol # 2: Launch the attack from compromised node - set the IER - ier_id = str(2000 + n) + ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode ier_start_step = np.random.randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps - ier_source_node_id = node.get_id() + # Randomise the load, as a percentage of a 
random link bandwith ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node - ier_service = node.get_services()[ - pol_service_name - ] # same service as defined in the pol - ier_port = ier_service.get_port() + ier_service = node.services[pol_service_name] + ier_port = ier_service.port ier_mission_criticality = ( 0 # Red IER will never be important to green agent success ) @@ -1325,15 +1323,15 @@ class Primaite(Env): possible_ier_destinations = [ ier.get_dest_node_id() for ier in list(self.green_iers.values()) - if ier.get_source_node_id() == node.get_id() + if ier.get_source_node_id() == node.node_id ] if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), - server.ip_address, - ier_service, - ier_port, + node.get_ip_address(), + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1347,37 +1345,42 @@ class Primaite(Env): ier_load, ier_protocol, ier_port, - ier_source_node_id, + node.node_id, ier_dest, ier_mission_criticality, ) - # 3: Make sure the targetted node can be set to overwhelmed - with node pol - # TODO remove duplicate red pol for same targetted service - must take into account start step + overwhelm_pol = red_pol + overwhelm_pol.id = str(uuid.uuid4()) + overwhelm_pol.end_step = self.episode_steps - o_pol_id = str(3000 + n) - o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - o_pol_end_step = ( - self.episode_steps - ) # Can become compromised at any timestep after start - o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - o_pol_service_name = ( - ier_protocol # Same protocol/service as the IER uses to attack - ) - o_pol_new_state = SoftwareState["OVERWHELMED"] - 
o_pol_entry_node = False # Assumes servers are not entry nodes + + # 3: Make sure the targetted node can be set to overwhelmed - with node pol + # # TODO remove duplicate red pol for same targetted service - must take into account start step + # + o_pol_id = str(uuid.uuid4()) + # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched + # o_pol_end_step = ( + # self.episode_steps + # ) # Can become compromised at any timestep after start + # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER + # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes + # o_pol_service_name = ( + # ier_protocol # Same protocol/service as the IER uses to attack + # ) + # o_pol_new_state = SoftwareState["OVERWHELMED"] + # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( - _id, - _start_step, - _end_step, - _target_node_id, - _pol_initiator, - _pol_type, - pol_protocol, - _pol_state, - _pol_source_node_id, - _pol_source_node_service, - _pol_source_node_service_state, + _id=o_pol_id, + _start_step=ier_start_step, + _end_step=self.episode_steps, + _target_node_id=ier_dest, + _pol_initiator="DIRECT", + _pol_type=NodePOLType["SERVICE"], + pol_protocol=ier_protocol, + _pol_state=SoftwareState.OVERWHELMED, + _pol_source_node_id=source_node.node_id, + _pol_source_node_service=source_node_service.name, + _pol_source_node_service_state=source_node_service.software_state ) self.red_node_pol[o_pol_id] = o_red_pol From c77fde3dd33489647feede79d86f0483cc1259c1 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 29 Jun 2023 15:26:07 +0100 Subject: [PATCH 17/43] Fix observation representation in transactions --- src/primaite/environment/observations.py | 149 +++++++++++++++--- src/primaite/environment/primaite_env.py | 5 +- src/primaite/main.py | 1 + .../transactions/transactions_to_file.py | 54 ++----- 4 files changed, 150 insertions(+), 59 deletions(-) diff --git 
a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index e6eb533c..023c5f30 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,6 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? + self.structure: list[str] return NotImplemented @abstractmethod @@ -36,6 +37,11 @@ class AbstractObservationComponent(ABC): """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented + @abstractmethod + def generate_structure(self) -> List[str]: + """Return a list of labels for the components of the flattened observation space.""" + return NotImplemented + class NodeLinkTable(AbstractObservationComponent): """Table with nodes and links as rows and hardware/software status as cols. @@ -79,6 +85,8 @@ class NodeLinkTable(AbstractObservationComponent): # 3. Initialise Observation with zeroes self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. 
@@ -131,6 +139,40 @@ class NodeLinkTable(AbstractObservationComponent): protocol_index += 1 item_index += 1 + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + nodes = self.env.nodes.values() + links = self.env.links.values() + + structure = [] + + for i, node in enumerate(nodes): + node_id = node.node_id + node_labels = [ + f"node_{node_id}_id", + f"node_{node_id}_hardware_status", + f"node_{node_id}_os_status", + f"node_{node_id}_fs_status", + ] + for j, serv in enumerate(self.env.services_list): + node_labels.append(f"node_{node_id}_service_{serv}_status") + + structure.extend(node_labels) + + for i, link in enumerate(links): + link_id = link.id + link_labels = [ + f"link_{link_id}_id", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + f"link_{link_id}_n/a", + ] + for j, serv in enumerate(self.env.services_list): + link_labels.append(f"node_{node_id}_service_{serv}_load") + + structure.extend(link_labels) + return structure + class NodeStatuses(AbstractObservationComponent): """Flat list of nodes' hardware, OS, file system, and service states. @@ -179,6 +221,7 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() def update(self): """Update the observation based on current environment state. 
@@ -205,6 +248,30 @@ class NodeStatuses(AbstractObservationComponent): ) self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + services = self.env.services_list + + structure = [] + for _, node in self.env.nodes.items(): + node_id = node.node_id + structure.append(f"node_{node_id}_hardware_state_NONE") + for state in HardwareState: + structure.append(f"node_{node_id}_hardware_state_{state.name}") + structure.append(f"node_{node_id}_software_state_NONE") + for state in SoftwareState: + structure.append(f"node_{node_id}_software_state_{state.name}") + structure.append(f"node_{node_id}_file_system_state_NONE") + for state in FileSystemState: + structure.append(f"node_{node_id}_file_system_state_{state.name}") + for service in services: + structure.append(f"node_{node_id}_service_{service}_state_NONE") + for state in SoftwareState: + structure.append( + f"node_{node_id}_service_{service}_state_{state.name}" + ) + return structure + class LinkTrafficLevels(AbstractObservationComponent): """Flat list of traffic levels encoded into banded categories. @@ -268,6 +335,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) + self.structure = self.generate_structure() + def update(self): """Update the observation based on current environment state. 
@@ -295,6 +364,21 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation[:] = obs + def generate_structure(self): + """Return a list of labels for the components of the flattened observation space.""" + structure = [] + for _, link in self.env.links.items(): + link_id = link.id + if self._combine_service_traffic: + protocols = ["overall"] + else: + protocols = [protocol.name for protocol in link.protocol_list] + + for p in protocols: + for i in range(self._quantisation_levels): + structure.append(f"link_{link_id}_{p}_traffic_level_{i}") + return structure + class ObservationsHandler: """Component-based observation space handler. @@ -312,11 +396,15 @@ class ObservationsHandler: def __init__(self): self.registered_obs_components: List[AbstractObservationComponent] = [] - # need to keep track of the flattened and unflattened version of the space (if there is one) - self.space: spaces.Space - self.unflattened_space: spaces.Space + # internal the observation space (unflattened version of space if flatten=True) + self._space: spaces.Space + # flattened version of the observation space + self._flat_space: spaces.Space + + self._observation: Union[Tuple[np.ndarray], np.ndarray] + # used for transactions and when flatten=true + self._flat_observation: np.ndarray - self.current_observation: Union[Tuple[np.ndarray], np.ndarray] self.flatten: bool = False def update_obs(self): @@ -326,17 +414,11 @@ class ObservationsHandler: obs.update() current_obs.append(obs.current_observation) - # If there is only one component, don't use a tuple, just pass through that component's obs. 
if len(current_obs) == 1: - self.current_observation = current_obs[0] - # If there are many compoenents, the space may need to be flattened + self._observation = current_obs[0] else: - if self.flatten: - self.current_observation = spaces.flatten( - self.unflattened_space, tuple(current_obs) - ) - else: - self.current_observation = tuple(current_obs) + self._observation = tuple(current_obs) + self._flat_observation = spaces.flatten(self._space, self._observation) def register(self, obs_component: AbstractObservationComponent): """Add a component for this handler to track. @@ -363,15 +445,28 @@ class ObservationsHandler: for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) - # If there is only one component, don't use a tuple space, just pass through that component's space. + # if there are multiple components, build a composite tuple space if len(component_spaces) == 1: - self.space = component_spaces[0] + self._space = component_spaces[0] else: - self.unflattened_space = spaces.Tuple(component_spaces) - if self.flatten: - self.space = spaces.flatten_space(spaces.Tuple(component_spaces)) - else: - self.space = self.unflattened_space + self._space = spaces.Tuple(component_spaces) + self._flat_space = spaces.flatten_space(self._space) + + @property + def space(self): + """Observation space, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_space + else: + return self._space + + @property + def current_observation(self): + """Current observation, return the flattened version if flatten is True.""" + if self.flatten: + return self._flat_observation + else: + return self._observation @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): @@ -417,3 +512,17 @@ class ObservationsHandler: handler.update_obs() return handler + + def describe_structure(self): + """Create a list of names for the features of the obs space. 
+ + The order of labels follows the flattened version of the space. + """ + # as it turns out it's not possible to take the gym flattening function and apply it to our labels so we have + # to fake it. each component has to just hard-code the expected label order after flattening... + + labels = [] + for obs_comp in self.registered_obs_components: + labels.extend(obs_comp.structure) + + return labels diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index be4cc434..e56abf9d 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -318,7 +318,8 @@ class Primaite(Env): datetime.now(), self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(copy.deepcopy(self.env_obs)) + transaction.set_obs_space_pre(self.obs_handler._flat_observation) + # Load the action space into the transaction transaction.set_action_space(copy.deepcopy(action)) @@ -400,7 +401,7 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() # Load the new observation space into the transaction - transaction.set_obs_space_post(copy.deepcopy(self.env_obs)) + transaction.set_obs_space_post(self.obs_handler._flat_observation) # 8. 
Add the transaction to the list of transactions self.transaction_list.append(copy.deepcopy(transaction)) diff --git a/src/primaite/main.py b/src/primaite/main.py index f5e94509..4d83f604 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -325,6 +325,7 @@ def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, transaction_list=transaction_list, session_path=session_dir, timestamp_str=timestamp_str, + obs_space_description=env.obs_handler.describe_structure(), ) print("Updating Session Metadata file...") diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index 11e68af8..b2a4d40d 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -22,24 +22,12 @@ def turn_action_space_to_array(_action_space): return [str(_action_space)] -def turn_obs_space_to_array(_obs_space, _obs_assets, _obs_features): - """ - Turns observation space into a string array so it can be saved to csv. - - Args: - _obs_space: The observation space - _obs_assets: The number of assets (i.e. nodes or links) in the observation space - _obs_features: The number of features associated with the asset - """ - return_array = [] - for x in range(_obs_assets): - for y in range(_obs_features): - return_array.append(str(_obs_space[x][y])) - - return return_array - - -def write_transaction_to_file(transaction_list, session_path: Path, timestamp_str: str): +def write_transaction_to_file( + transaction_list, + session_path: Path, + timestamp_str: str, + obs_space_description: list, +): """ Writes transaction logs to file to support training evaluation. 
@@ -56,13 +44,13 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st # This will be tied into the PrimAITE Use Case so that they make sense template_transation = transaction_list[0] action_length = template_transation.action_space.size - obs_shape = template_transation.obs_space_post.shape - obs_assets = template_transation.obs_space_post.shape[0] - if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - obs_features = 1 - else: - obs_features = template_transation.obs_space_post.shape[1] + # obs_shape = template_transation.obs_space_post.shape + # obs_assets = template_transation.obs_space_post.shape[0] + # if len(obs_shape) == 1: + # bit of a workaround but I think the way transactions are written will change soon + # obs_features = 1 + # else: + # obs_features = template_transation.obs_space_post.shape[1] # Create the action space headers array action_header = [] @@ -70,12 +58,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st action_header.append("AS_" + str(x)) # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) + obs_header_initial = [f"pre_{o}" for o in obs_space_description] + obs_header_new = [f"post_{o}" for o in obs_space_description] # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] @@ -98,12 +82,8 @@ def write_transaction_to_file(transaction_list, session_path: Path, timestamp_st csv_data = ( csv_data + turn_action_space_to_array(transaction.action_space) - + turn_obs_space_to_array( - transaction.obs_space_pre, obs_assets, obs_features - ) - + turn_obs_space_to_array( - transaction.obs_space_post, obs_assets, obs_features - ) + + transaction.obs_space_pre.tolist() + + 
transaction.obs_space_post.tolist() ) csv_writer.writerow(csv_data) From 73015802ece6109f18ac80044c440560bef24fdf Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 30 Jun 2023 09:08:13 +0100 Subject: [PATCH 18/43] #917 - Integrated the PrimaiteSession into all tests. - Ran a full pre-commit hook and thus encountered tons of fixes required --- .azure/azure-ci-build-pipeline.yaml | 2 +- .pre-commit-config.yaml | 6 +- docs/source/config.rst | 6 +- docs/source/primaite-dependencies.rst | 58 ++++- pyproject.toml | 11 + src/primaite/VERSION | 2 +- src/primaite/__init__.py | 19 +- src/primaite/acl/access_control_list.py | 28 ++- src/primaite/acl/acl_rule.py | 8 +- src/primaite/agents/agent.py | 190 ++++++++++---- src/primaite/agents/hardcoded_acl.py | 231 ++++++++++++------ src/primaite/agents/hardcoded_node.py | 64 +++-- src/primaite/agents/rllib.py | 115 +++++---- src/primaite/agents/sb3.py | 103 +++++--- src/primaite/agents/utils.py | 230 ++++++++++------- src/primaite/cli.py | 53 +++- src/primaite/common/enums.py | 9 + .../training/training_config_main.yaml | 8 +- src/primaite/config/lay_down_config.py | 17 +- src/primaite/config/training_config.py | 72 +++--- src/primaite/data_viz/__init__.py | 13 + src/primaite/data_viz/session_plots.py | 73 ++++++ src/primaite/environment/observations.py | 29 ++- src/primaite/environment/primaite_env.py | 108 ++++---- src/primaite/environment/reward.py | 23 +- src/primaite/links/link.py | 4 +- src/primaite/main.py | 5 +- src/primaite/nodes/active_node.py | 17 +- .../nodes/node_state_instruction_green.py | 4 +- .../nodes/node_state_instruction_red.py | 4 +- src/primaite/nodes/service_node.py | 4 +- src/primaite/notebooks/__init__.py | 2 +- src/primaite/pol/green_pol.py | 38 ++- src/primaite/pol/red_agent_pol.py | 33 ++- src/primaite/primaite_session.py | 175 ++++++++----- .../setup/_package_data/primaite_config.yaml | 11 + src/primaite/setup/reset_demo_notebooks.py | 6 +- src/primaite/setup/reset_example_configs.py | 6 +- 
src/primaite/setup/setup_app_dirs.py | 2 +- src/primaite/transactions/transaction.py | 26 +- src/primaite/utils/session_output_reader.py | 20 ++ src/primaite/utils/session_output_writer.py | 38 ++- tests/config/legacy/new_training_config.yaml | 4 +- .../main_config_LINK_TRAFFIC_LEVELS.yaml | 27 +- .../main_config_NODE_LINK_TABLE.yaml | 27 +- .../obs_tests/main_config_NODE_STATUSES.yaml | 27 +- .../obs_tests/main_config_without_obs.yaml | 27 +- .../one_node_states_on_off_main_config.yaml | 26 +- ..._space_fixed_blue_actions_main_config.yaml | 27 +- .../single_action_space_main_config.yaml | 27 +- tests/conftest.py | 128 +++++++++- .../test_primaite_main.py | 8 - .../__init__.py | 0 tests/mock_and_patch/get_session_path_mock.py | 24 ++ tests/test_active_node.py | 12 +- tests/test_observation_space.py | 210 +++++++++------- tests/test_primaite_session.py | 61 +++++ tests/test_resetting_node.py | 4 +- tests/test_reward.py | 38 +-- tests/test_service_node.py | 8 +- tests/test_single_action_space.py | 120 +++++---- tests/test_training_config.py | 4 +- 62 files changed, 1880 insertions(+), 802 deletions(-) create mode 100644 src/primaite/data_viz/__init__.py create mode 100644 src/primaite/data_viz/session_plots.py create mode 100644 src/primaite/utils/session_output_reader.py delete mode 100644 tests/e2e_integration_tests/test_primaite_main.py rename tests/{e2e_integration_tests => mock_and_patch}/__init__.py (100%) create mode 100644 tests/mock_and_patch/get_session_path_mock.py create mode 100644 tests/test_primaite_session.py diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 8bfdca02..691f71e9 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -59,4 +59,4 @@ steps: - script: | pytest tests/ - displayName: 'Run unmarked tests' + displayName: 'Run tests' diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a08b17b8..26cd5697 100644 --- a/.pre-commit-config.yaml +++ 
b/.pre-commit-config.yaml @@ -13,6 +13,9 @@ repos: rev: 23.1.0 hooks: - id: black + args: [ "--line-length=79" ] + additional_dependencies: + - jupyter - repo: http://github.com/pycqa/isort rev: 5.12.0 hooks: @@ -22,4 +25,5 @@ repos: rev: 6.0.0 hooks: - id: flake8 - additional_dependencies: [ flake8-docstrings ] + additional_dependencies: + - flake8-docstrings diff --git a/docs/source/config.rst b/docs/source/config.rst index 52748eec..22fd0c01 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -22,7 +22,7 @@ The environment config file consists of the following attributes: * **agent_framework** [enum] - + This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following: * NONE - Where a user developed agent is to be used @@ -30,14 +30,14 @@ The environment config file consists of the following attributes: * RLLIB - Ray RLlib. * **agent_identifier** - + This identifies the agent to use for the session. Select from one of the following: * A2C - Advantage Actor Critic * PPO - Proximal Policy Optimization * HARDCODED - A custom built deterministic agent * RANDOM - A Stochastic random agent - + * **action_type** [enum] diff --git a/docs/source/primaite-dependencies.rst b/docs/source/primaite-dependencies.rst index bf6bd6e3..48f835fe 100644 --- a/docs/source/primaite-dependencies.rst +++ b/docs/source/primaite-dependencies.rst @@ -47,6 +47,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | asttokens | 2.2.1 | Apache 2.0 | https://github.com/gristlabs/asttokens | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| astunparse | 1.6.3 | BSD License | 
https://github.com/simonpercivall/astunparse | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | attrs | 23.1.0 | MIT License | https://www.attrs.org/en/stable/changelog.html | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | backcall | 0.2.0 | BSD License | https://github.com/takluyver/backcall | @@ -103,6 +105,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | flake8 | 6.0.0 | MIT License | https://github.com/pycqa/flake8 | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| flatbuffers | 23.5.26 | Apache Software License | https://google.github.io/flatbuffers/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | fonttools | 4.39.4 | MIT License | http://github.com/fonttools/fonttools | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | fqdn | 1.5.1 | Mozilla Public License 2.0 (MPL 2.0) | https://github.com/ypcrts/fqdn | @@ -111,9 +115,13 @@ 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | furo | 2023.3.27 | MIT License | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| gast | 0.4.0 | BSD License | https://github.com/serge-sans-paille/gast/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | google-auth | 2.19.0 | Apache Software License | https://github.com/googleapis/google-auth-library-python | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| google-auth-oauthlib | 1.0.0 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib | +| google-auth-oauthlib | 0.4.6 | Apache Software License | https://github.com/GoogleCloudPlatform/google-auth-library-python-oauthlib | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| google-pasta | 0.2.0 | Apache Software License | https://github.com/google/pasta | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | grpcio | 1.51.3 | Apache Software License 
| https://grpc.io | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ @@ -121,6 +129,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | gymnasium-notices | 0.0.1 | MIT License | https://github.com/Farama-Foundation/gym-notices | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| h5py | 3.9.0 | BSD License | https://www.h5py.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | identify | 2.5.24 | MIT License | https://github.com/pre-commit/identify | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | idna | 3.4 | BSD License | https://github.com/kjd/idna | @@ -141,6 +151,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | isoduration | 20.11.0 | ISC License (ISCL) | https://github.com/bolsote/isoduration | 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jax | 0.4.12 | Apache-2.0 | https://github.com/google/jax | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jedi | 0.18.2 | MIT License | https://github.com/davidhalter/jedi | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | json5 | 0.9.14 | Apache Software License | https://github.com/dpranke/pyjson5 | @@ -151,14 +163,14 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter-events | 0.6.3 | BSD License | http://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| jupyter-server | 1.24.0 | BSD License | https://jupyter-server.readthedocs.io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter-ydoc | 0.2.4 | BSD 3-Clause License | https://jupyter.org | 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_client | 8.2.0 | BSD License | https://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_core | 5.3.0 | BSD License | https://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| jupyter_server | 2.6.0 | BSD License | https://jupyter-server.readthedocs.io | -+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_server_fileid | 0.9.0 | BSD License | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyter_server_terminals | 0.4.4 | BSD License | https://jupyter.org | @@ -171,10 +183,14 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | jupyterlab_server | 2.22.1 | BSD License | https://jupyterlab-server.readthedocs.io | 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| keras | 2.12.0 | Apache Software License | https://keras.io/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | kiwisolver | 1.4.4 | BSD License | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | lazy_loader | 0.2 | BSD License | https://github.com/scientific-python/lazy_loader | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| libclang | 16.0.0 | Apache Software License | https://github.com/sighingnow/libclang | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | lz4 | 4.3.2 | BSD License | https://github.com/python-lz4/python-lz4 | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | markdown-it-py | 2.2.0 | MIT License | https://github.com/executablebooks/markdown-it-py | @@ -183,19 +199,23 @@ 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | matplotlib-inline | 0.1.6 | BSD 3-Clause | https://github.com/ipython/matplotlib-inline | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| mavizstyle | 1.0.0 | UNKNOWN | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mccabe | 0.7.0 | MIT License | https://github.com/pycqa/mccabe | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mdurl | 0.1.2 | MIT License | https://github.com/executablebooks/mdurl | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mistune | 2.0.5 | BSD License | https://github.com/lepture/mistune | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ml-dtypes | 0.2.0 | Apache Software License | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ 
| mock | 5.0.2 | BSD License | http://mock.readthedocs.org/en/latest/ | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | mpmath | 1.3.0 | BSD License | http://mpmath.org/ | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | msgpack | 1.0.5 | Apache Software License | https://msgpack.org/ | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| nbclassic | 1.0.0 | BSD License | https://github.com/jupyter/nbclassic | +| nbclassic | 0.5.6 | BSD License | https://github.com/jupyter/nbclassic | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | nbclient | 0.8.0 | BSD License | https://jupyter.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ @@ -217,6 +237,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | oauthlib | 3.2.2 | BSD License | https://github.com/oauthlib/oauthlib | 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| opt-einsum | 3.3.0 | MIT | https://github.com/dgasmith/opt_einsum | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | overrides | 7.3.1 | Apache License, Version 2.0 | https://github.com/mkorpela/overrides | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | packaging | 23.1 | Apache Software License; BSD License | https://github.com/pypa/packaging | @@ -231,11 +253,17 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | platformdirs | 3.5.1 | MIT License | https://github.com/platformdirs/platformdirs | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| plotly | 5.15.0 | MIT License | https://plotly.com/python/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | pluggy | 1.0.0 | MIT License | https://github.com/pytest-dev/pluggy | 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | pre-commit | 2.20.0 | MIT License | https://github.com/pre-commit/pre-commit | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| primaite | 1.2.1 | GFX | UNKNOWN | +| primaite | 2.0.0rc1 | GFX | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| primaite | 2.0.0rc1 | GFX | UNKNOWN | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| primaite | 2.0.0rc1 | GFX | UNKNOWN | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | prometheus-client | 0.17.0 | Apache Software License | https://github.com/prometheus/client_python | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ @@ -295,6 +323,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | rsa | 4.9 | Apache Software License | 
https://stuvel.eu/rsa | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| ruff | 0.0.272 | MIT License | https://github.com/charliermarsh/ruff | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | scikit-image | 0.20.0 | BSD License | https://scikit-image.org | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | scipy | 1.10.1 | BSD License | https://scipy.org/ | @@ -335,14 +365,26 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tabulate | 0.9.0 | MIT License | https://github.com/astanin/python-tabulate | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| tensorboard | 2.12.3 | Apache Software License | https://github.com/tensorflow/tensorboard | +| tenacity | 8.2.2 | Apache Software License | https://github.com/jd/tenacity | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ -| tensorboard-data-server | 0.7.0 | Apache Software License | 
https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server | +| tensorboard | 2.11.2 | Apache Software License | https://github.com/tensorflow/tensorboard | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorboard-data-server | 0.6.1 | Apache Software License | https://github.com/tensorflow/tensorboard/tree/master/tensorboard/data/server | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tensorboard-plugin-wit | 1.8.1 | Apache 2.0 | https://whatif-tool.dev | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tensorboardX | 2.6 | MIT License | https://github.com/lanpa/tensorboardX | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-estimator | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | 
++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-intel | 2.12.0 | Apache Software License | https://www.tensorflow.org/ | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| tensorflow-io-gcs-filesystem | 0.31.0 | Apache Software License | https://github.com/tensorflow/io | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| termcolor | 2.3.0 | MIT License | https://github.com/termcolor/termcolor | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | terminado | 0.17.1 | BSD License | https://github.com/jupyter/terminado | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | tifffile | 2023.4.12 | BSD License | https://www.cgohlke.com | @@ -377,6 +419,8 @@ +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | websocket-client | 1.5.2 | Apache Software License | https://github.com/websocket-client/websocket-client.git | 
+-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ +| wrapt | 1.14.1 | BSD License | https://github.com/GrahamDumpleton/wrapt | ++-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | y-py | 0.5.9 | MIT License | https://github.com/y-crdt/ypy | +-------------------------------+-------------+--------------------------------------------------------------------------------------------------+-------------------------------------------------------------------------------+ | ypy-websocket | 0.8.2 | UNKNOWN | https://github.com/y-crdt/ypy-websocket | diff --git a/pyproject.toml b/pyproject.toml index aa9f5fdc..09b60777 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,8 @@ dependencies = [ "networkx==3.1", "numpy==1.23.5", "platformdirs==3.5.1", + "plotly==5.15.0", + "polars==0.18.4", "PyYAML==6.0", "ray[rllib]==2.2.0", "stable-baselines3==1.6.2", @@ -69,3 +71,12 @@ tensorflow = [ [project.scripts] primaite = "primaite.cli:app" + +[tool.isort] +profile = "black" +line_length = 79 +force_sort_within_sections = "False" +order_by_type = "False" + +[tool.black] +line-length = 79 diff --git a/src/primaite/VERSION b/src/primaite/VERSION index 3068ee27..4111d137 100644 --- a/src/primaite/VERSION +++ b/src/primaite/VERSION @@ -1 +1 @@ -2.0.0rc1 \ No newline at end of file +2.0.0rc1 diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index 64857c80..e753b4ef 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -3,12 +3,10 @@ import logging import logging.config import sys from bisect import bisect -from logging import Formatter, LogRecord, StreamHandler -from logging import Logger +from logging 
import Formatter, Logger, LogRecord, StreamHandler from logging.handlers import RotatingFileHandler from pathlib import Path -from typing import Dict -from typing import Final +from typing import Dict, Final import pkg_resources import yaml @@ -21,7 +19,6 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") def _get_primaite_config(): config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" if not config_path.exists(): - config_path = Path( pkg_resources.resource_filename( "primaite", "setup/_package_data/primaite_config.yaml" @@ -37,7 +34,9 @@ def _get_primaite_config(): "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] + primaite_config["log_level"] = log_level_map[ + primaite_config["logging"]["log_level"] + ] return primaite_config @@ -111,9 +110,13 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], - logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], + logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][ + "WARNING" + ], logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], - logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"] + logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][ + "CRITICAL" + ], } ) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index 284ed764..a147d963 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,7 +10,9 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules + self.acl: Dict[ + str, AccessControlList + ] = {} # A dictionary of ACL Rules def check_address_match(self, _rule, _source_ip_address, 
_dest_ip_address): """ @@ -37,13 +39,17 @@ class AccessControlList: _rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY" ) - or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY") + or ( + _rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY" + ) ): return True else: return False - def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): + def is_blocked( + self, _source_ip_address, _dest_ip_address, _protocol, _port + ): """ Checks for rules that block a protocol / port. @@ -87,7 +93,9 @@ class AccessControlList: _protocol: the protocol _port: the port """ - new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) + new_rule = ACLRule( + _permission, _source_ip, _dest_ip, _protocol, str(_port) + ) hash_value = hash(new_rule) self.acl[hash_value] = new_rule @@ -102,7 +110,9 @@ class AccessControlList: _protocol: the protocol _port: the port """ - rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) + rule = ACLRule( + _permission, _source_ip, _dest_ip, _protocol, str(_port) + ) hash_value = hash(rule) # There will not always be something 'popable' since the agent will be trying random things try: @@ -114,7 +124,9 @@ class AccessControlList: """Removes all rules.""" self.acl.clear() - def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): + def get_dictionary_hash( + self, _permission, _source_ip, _dest_ip, _protocol, _port + ): """ Produces a hash value for a rule. @@ -128,6 +140,8 @@ class AccessControlList: Returns: Hash value based on rule parameters. 
""" - rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) + rule = ACLRule( + _permission, _source_ip, _dest_ip, _protocol, str(_port) + ) hash_value = hash(rule) return hash_value diff --git a/src/primaite/acl/acl_rule.py b/src/primaite/acl/acl_rule.py index ef631a70..05daecc4 100644 --- a/src/primaite/acl/acl_rule.py +++ b/src/primaite/acl/acl_rule.py @@ -30,7 +30,13 @@ class ACLRule: Returns hash of core parameters. """ return hash( - (self.permission, self.source_ip, self.dest_ip, self.protocol, self.port) + ( + self.permission, + self.source_ip, + self.dest_ip, + self.protocol, + self.port, + ) ) def get_permission(self): diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index f545a3cb..c76583c0 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -1,27 +1,31 @@ from __future__ import annotations + import json import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Optional, Final, Dict, Union +from typing import Dict, Final, Optional, Union from uuid import uuid4 import yaml import primaite from primaite import getLogger, SESSIONS_DIR -from primaite.config import lay_down_config -from primaite.config import training_config +from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig +from primaite.data_viz.session_plots import plot_av_reward_per_episode from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) -def _get_session_path(session_timestamp: datetime) -> Path: +def get_session_path(session_timestamp: datetime) -> Path: """ - Get a temp directory session path the test session will output to. + Get the directory path the session will output to. + + This is set in the format of: + ~/primaite/sessions//_. :param session_timestamp: This is the datetime that the session started. :return: The session directory path. 
@@ -35,13 +39,15 @@ def _get_session_path(session_timestamp: datetime) -> Path: class AgentSessionABC(ABC): + """ + An ABC that manages training and/or evaluation of agents in PrimAITE. + + This class cannot be directly instantiated and must be inherited from + with all implemented abstract methods implemented. + """ @abstractmethod - def __init__( - self, - training_config_path, - lay_down_config_path - ): + def __init__(self, training_config_path, lay_down_config_path): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path @@ -66,9 +72,8 @@ class AgentSessionABC(ABC): self._uuid = str(uuid4()) self.session_timestamp: datetime = datetime.now() "The session timestamp" - self.session_path = _get_session_path(self.session_timestamp) + self.session_path = get_session_path(self.session_timestamp) "The Session path" - self.checkpoints_path.mkdir(parents=True, exist_ok=True) @property def timestamp_str(self) -> str: @@ -78,17 +83,23 @@ class AgentSessionABC(ABC): @property def learning_path(self) -> Path: """The learning outputs path.""" - return self.session_path / "learning" + path = self.session_path / "learning" + path.mkdir(exist_ok=True, parents=True) + return path @property def evaluation_path(self) -> Path: """The evaluation outputs path.""" - return self.session_path / "evaluation" + path = self.session_path / "evaluation" + path.mkdir(exist_ok=True, parents=True) + return path @property def checkpoints_path(self) -> Path: """The Session checkpoints path.""" - return self.learning_path / "checkpoints" + path = self.learning_path / "checkpoints" + path.mkdir(exist_ok=True, parents=True) + return path @property def uuid(self): @@ -118,14 +129,8 @@ class AgentSessionABC(ABC): "uuid": self.uuid, "start_datetime": self.session_timestamp.isoformat(), "end_datetime": None, - "learning": { - "total_episodes": None, - "total_time_steps": None - }, - 
"evaluation": { - "total_episodes": None, - "total_time_steps": None - }, + "learning": {"total_episodes": None, "total_time_steps": None}, + "evaluation": {"total_episodes": None, "total_time_steps": None}, "env": { "training_config": self._training_config.to_dict( json_serializable=True @@ -156,11 +161,19 @@ class AgentSessionABC(ABC): metadata_dict["end_datetime"] = datetime.now().isoformat() if not self.is_eval: - metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa - metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa + metadata_dict["learning"][ + "total_episodes" + ] = self._env.episode_count # noqa + metadata_dict["learning"][ + "total_time_steps" + ] = self._env.total_step_count # noqa else: - metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa - metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa + metadata_dict["evaluation"][ + "total_episodes" + ] = self._env.episode_count # noqa + metadata_dict["evaluation"][ + "total_time_steps" + ] = self._env.total_step_count # noqa filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -187,26 +200,47 @@ class AgentSessionABC(ABC): @abstractmethod def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. 
+ """ if self._can_learn: _LOGGER.info("Finished learning") _LOGGER.debug("Writing transactions") self._update_session_metadata_file() self._can_evaluate = True self.is_eval = False + self._plot_av_reward_per_episode(learning_session=True) @abstractmethod def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ + self._env.set_as_eval() # noqa self.is_eval = True + self._plot_av_reward_per_episode(learning_session=False) _LOGGER.info("Finished evaluation") @abstractmethod @@ -216,6 +250,7 @@ class AgentSessionABC(ABC): @classmethod @abstractmethod def load(cls, path: Union[str, Path]) -> AgentSessionABC: + """Load an agent from file.""" if not isinstance(path, Path): path = Path(path) @@ -246,21 +281,56 @@ class AgentSessionABC(ABC): else: # Session path does not exist - msg = f"Failed to load PrimAITE Session, path does not exist: {path}" + msg = ( + f"Failed to load PrimAITE Session, path does not exist: {path}" + ) _LOGGER.error(msg) raise FileNotFoundError(msg) pass @abstractmethod def save(self): + """Save the agent.""" self._agent.save(self.session_path) @abstractmethod def export(self): + """Export the agent to transportable file format.""" pass + def close(self): + """Closes the agent.""" + self._env.episode_av_reward_writer.close() # noqa + self._env.transaction_writer.close() # noqa + + def _plot_av_reward_per_episode(self, learning_session: bool = True): + # self.close() + title = f"PrimAITE Session {self.timestamp_str} " + subtitle = str(self._training_config) + csv_file = 
f"average_reward_per_episode_{self.timestamp_str}.csv" + image_file = f"average_reward_per_episode_{self.timestamp_str}.png" + if learning_session: + title += "(Learning)" + path = self.learning_path / csv_file + image_path = self.learning_path / image_file + else: + title += "(Evaluation)" + path = self.evaluation_path / csv_file + image_path = self.evaluation_path / image_file + + fig = plot_av_reward_per_episode(path, title, subtitle) + fig.write_image(image_path) + _LOGGER.debug(f"Saved average rewards per episode plot to: {path}") + class HardCodedAgentSessionABC(AgentSessionABC): + """ + An Agent Session ABC for evaluation deterministic agents. + + This class cannot be directly instantiated and must be inherited from + with all implemented abstract methods implemented. + """ + def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) self._setup() @@ -270,13 +340,12 @@ class HardCodedAgentSessionABC(AgentSessionABC): training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, session_path=self.session_path, - timestamp_str=self.timestamp_str + timestamp_str=self.timestamp_str, ) super()._setup() self._can_learn = False self._can_evaluate = True - def _save_checkpoint(self): pass @@ -284,11 +353,20 @@ class HardCodedAgentSessionABC(AgentSessionABC): pass def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. 
+ """ _LOGGER.warning("Deterministic agents cannot learn") @abstractmethod @@ -296,20 +374,31 @@ class HardCodedAgentSessionABC(AgentSessionABC): pass def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ + self._env.set_as_eval() # noqa + self.is_eval = True + if not time_steps: time_steps = self._training_config.num_steps if not episodes: episodes = self._training_config.num_episodes - + obs = self._env.reset() for episode in range(episodes): # Reset env and collect initial observation - obs = self._env.reset() for step in range(time_steps): # Calculate action action = self._calculate_action(obs) @@ -322,15 +411,18 @@ class HardCodedAgentSessionABC(AgentSessionABC): # Introduce a delay between steps time.sleep(self._training_config.time_delay / 1000) + obs = self._env.reset() self._env.close() - super().evaluate() @classmethod def load(cls): + """Load an agent from file.""" _LOGGER.warning("Deterministic agents cannot be loaded") def save(self): + """Save the agent.""" _LOGGER.warning("Deterministic agents cannot be saved") def export(self): + """Export the agent to transportable file format.""" _LOGGER.warning("Deterministic agents cannot be exported") diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index 4ad08f6e..f70320f1 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -11,9 +11,13 @@ from primaite.common.enums import HardCodedAgentView class HardCodedACLAgent(HardCodedAgentSessionABC): + 
"""An Agent Session class that implements a deterministic ACL agent.""" def _calculate_action(self, obs): - if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: + if ( + self._training_config.hard_coded_agent_view + == HardCodedAgentView.BASIC + ): # Basic view action using only the current observation return self._calculate_action_basic_view(obs) else: @@ -22,6 +26,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return self._calculate_action_full_view(obs) def get_blocked_green_iers(self, green_iers, acl, nodes): + """ + Get blocked green IERs. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ blocked_green_iers = {} for green_ier_id, green_ier in green_iers.items(): @@ -33,8 +43,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): port = green_ier.get_port() # Can be blocked by an ACL or by default (no allow rule exists) - if acl.is_blocked(source_node_address, dest_node_address, protocol, - port): + if acl.is_blocked( + source_node_address, dest_node_address, protocol, port + ): blocked_green_iers[green_ier_id] = green_ier return blocked_green_iers @@ -42,8 +53,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_matching_acl_rules_for_ier(self, ier, acl, nodes): """ Get matching ACL rules for an IER. - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ source_node_id = ier.get_source_node_id() source_node_address = nodes[source_node_id].ip_address dest_node_id = ier.get_dest_node_id() @@ -51,17 +64,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): protocol = ier.get_protocol() # e.g. 'TCP' port = ier.get_port() - matching_rules = acl.get_relevant_rules(source_node_address, - dest_node_address, protocol, - port) + matching_rules = acl.get_relevant_rules( + source_node_address, dest_node_address, protocol, port + ) return matching_rules def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): """ Get blocking ACL rules for an IER. 
- Warning: Can return empty dict but IER can still be blocked by default (No ALLOW rule, therefore blocked) - """ + .. warning:: + Can return empty dict but IER can still be blocked by default + (No ALLOW rule, therefore blocked). + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) blocked_rules = {} @@ -74,8 +92,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def get_allow_acl_rules_for_ier(self, ier, acl, nodes): """ Get all allowing ACL rules for an IER. - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ matching_rules = self.get_matching_acl_rules_for_ier(ier, acl, nodes) allowed_rules = {} @@ -85,9 +105,22 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def get_matching_acl_rules(self, source_node_id, dest_node_id, protocol, - port, acl, - nodes, services_list): + def get_matching_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get matching ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ if source_node_id != "ANY": source_node_address = nodes[str(source_node_id)].ip_address else: @@ -100,21 +133,39 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if protocol != "ANY": protocol = services_list[ - protocol - 1] # -1 as dont have to account for ANY in list of services + protocol - 1 + ] # -1 as dont have to account for ANY in list of services - matching_rules = acl.get_relevant_rules(source_node_address, - dest_node_address, protocol, - port) + matching_rules = acl.get_relevant_rules( + source_node_address, dest_node_address, protocol, port + ) return matching_rules - def get_allow_acl_rules(self, source_node_id, dest_node_id, protocol, - port, acl, - nodes, services_list): - matching_rules = self.get_matching_acl_rules(source_node_id, - dest_node_id, - protocol, port, acl, - nodes, - services_list) + def get_allow_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get the ALLOW ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ + matching_rules = self.get_matching_acl_rules( + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ) allowed_rules = {} for rule_key, rule_value in matching_rules.items(): @@ -123,14 +174,31 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return allowed_rules - def get_deny_acl_rules(self, source_node_id, dest_node_id, protocol, port, - acl, - nodes, services_list): - matching_rules = self.get_matching_acl_rules(source_node_id, - dest_node_id, - protocol, port, acl, - nodes, - services_list) + def get_deny_acl_rules( + self, + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ): + """ + Get the DENY ACL rules. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ + matching_rules = self.get_matching_acl_rules( + source_node_id, + dest_node_id, + protocol, + port, + acl, + nodes, + services_list, + ) allowed_rules = {} for rule_key, rule_value in matching_rules.items(): @@ -141,7 +209,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): def _calculate_action_full_view(self, obs): """ - Given an observation and the environment calculate a good acl-based action for the blue agent to take + Calculate a good acl-based action for the blue agent to take. Knowledge of just the observation space is insufficient for a perfect solution, as we need to know: @@ -167,8 +235,10 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): nodes once a service becomes overwhelmed. However currently the ACL action space has no way of reversing an overwhelmed state, so we don't do this. + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ - #obs = convert_to_old_obs(obs) + # obs = convert_to_old_obs(obs) r_obs = transform_change_obs_readable(obs) _, _, _, *s = r_obs @@ -184,7 +254,6 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): for service_num, service_states in enumerate(s): for x, service_state in enumerate(service_states): if service_state == "COMPROMISED": - action_source_id = x + 1 # +1 as 0 is any action_destination_id = "ANY" action_protocol = service_num + 1 # +1 as 0 is any @@ -215,19 +284,23 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action_permission = "ALLOW" action_source_ip = rule.get_source_ip() action_source_id = int( - get_node_of_ip(action_source_ip, self._env.nodes)) + get_node_of_ip(action_source_ip, self._env.nodes) + ) action_destination_ip = rule.get_dest_ip() action_destination_id = int( - get_node_of_ip(action_destination_ip, - self._env.nodes)) + get_node_of_ip( + action_destination_ip, self._env.nodes + ) + ) action_protocol_name = rule.get_protocol() action_protocol = ( - self._env.services_list.index( - action_protocol_name) + 1 + 
self._env.services_list.index(action_protocol_name) + + 1 ) # convert name e.g. 'TCP' to index action_port_name = rule.get_port() - action_port = self._env.ports_list.index( - action_port_name) + 1 # convert port name e.g. '80' to index + action_port = ( + self._env.ports_list.index(action_port_name) + 1 + ) # convert port name e.g. '80' to index found_action = True break @@ -258,21 +331,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if not found_action: # Which Green IERS are blocked blocked_green_iers = self.get_blocked_green_iers( - self._env.green_iers, self._env.acl, - self._env.nodes) + self._env.green_iers, self._env.acl, self._env.nodes + ) for ier_key, ier in blocked_green_iers.items(): - # Which ALLOW rules are allowing this IER (none) - allowing_rules = self.get_allow_acl_rules_for_ier(ier, - self._env.acl, - self._env.nodes) + allowing_rules = self.get_allow_acl_rules_for_ier( + ier, self._env.acl, self._env.nodes + ) # If there are no blocking rules, it may be being blocked by default # If there is already an allow rule node_id_to_check = int(ier.get_source_node_id()) service_name_to_check = ier.get_protocol() service_id_to_check = self._env.services_list.index( - service_name_to_check) + service_name_to_check + ) # Service state of the the source node in the ier service_state = s[service_id_to_check][node_id_to_check - 1] @@ -283,11 +356,13 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action_source_id = int(ier.get_source_node_id()) action_destination_id = int(ier.get_dest_node_id()) action_protocol_name = ier.get_protocol() - action_protocol = self._env.services_list.index( - action_protocol_name) + 1 # convert name e.g. 'TCP' to index + action_protocol = ( + self._env.services_list.index(action_protocol_name) + 1 + ) # convert name e.g. 'TCP' to index action_port_name = ier.get_port() - action_port = self._env.ports_list.index( - action_port_name) + 1 # convert port name e.g. 
'80' to index + action_port = ( + self._env.ports_list.index(action_port_name) + 1 + ) # convert port name e.g. '80' to index found_action = True break @@ -311,19 +386,25 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): return action def _calculate_action_basic_view(self, obs): - """Given an observation calculate a good acl-based action for the blue agent to take + """Calculate a good acl-based action for the blue agent to take. - Uses ONLY information from the current observation with NO knowledge of previous actions taken and - NO reward feedback. + Uses ONLY information from the current observation with NO knowledge + of previous actions taken and NO reward feedback. - We rely on randomness to select the precise action, as we want to block all traffic originating from - a compromised node, without being able to tell: + We rely on randomness to select the precise action, as we want to + block all traffic originating from a compromised node, without being + able to tell: 1. Which ACL rules already exist - 1. Which actions the agent has already tried. + 2. Which actions the agent has already tried. - There is a high probability that the correct rule will not be deleted before the state becomes overwhelmed. + There is a high probability that the correct rule will not be deleted + before the state becomes overwhelmed. - Currently a deny rule does not overwrite an allow rule. The allow rules must be deleted. + Currently, a deny rule does not overwrite an allow rule. The allow + rules must be deleted. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
""" action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) @@ -333,27 +414,35 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): s = [*s] number_of_nodes = len( - [i for i in o if i != "NONE"]) # number of nodes (not links) + [i for i in o if i != "NONE"] + ) # number of nodes (not links) for service_num, service_states in enumerate(s): - comprimised_states = [n for n, i in enumerate(service_states) if - i == "COMPROMISED"] + comprimised_states = [ + n for n, i in enumerate(service_states) if i == "COMPROMISED" + ] if len(comprimised_states) == 0: # No states are COMPROMISED, try the next service continue - compromised_node = np.random.choice( - comprimised_states) + 1 # +1 as 0 would be any + compromised_node = ( + np.random.choice(comprimised_states) + 1 + ) # +1 as 0 would be any action_decision = "DELETE" action_permission = "ALLOW" action_source_ip = compromised_node # Randomly select a destination ID to block action_destination_ip = np.random.choice( - list(range(1, number_of_nodes + 1)) + ["ANY"]) - action_destination_ip = int( - action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip + list(range(1, number_of_nodes + 1)) + ["ANY"] + ) + action_destination_ip = ( + int(action_destination_ip) + if action_destination_ip != "ANY" + else action_destination_ip + ) action_protocol = service_num + 1 # +1 as 0 is any # Randomly select a port to block - # Bad assumption that number of protocols equals number of ports AND no rules exist with an ANY port + # Bad assumption that number of protocols equals number of ports + # AND no rules exist with an ANY port action_port = np.random.choice(list(range(1, len(s) + 1))) action = [ diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index 6db43da6..e258edb0 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,16 +1,21 @@ from primaite.agents.agent import HardCodedAgentSessionABC from 
primaite.agents.utils import ( get_new_action, - transform_change_obs_readable, -) -from primaite.agents.utils import ( transform_action_node_enum, + transform_change_obs_readable, ) class HardCodedNodeAgent(HardCodedAgentSessionABC): + """An Agent Session class that implements a deterministic Node agent.""" + def _calculate_action(self, obs): - """Given an observation calculate a good node-based action for the blue agent to take""" + """ + Calculate a good node-based action for the blue agent to take. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_dict = self._env.action_dict r_obs = transform_change_obs_readable(obs) _, o, os, *s = r_obs @@ -18,7 +23,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): if len(r_obs) == 4: # only 1 service s = [*s] - # Check in order of most important states (order doesn't currently matter, but it probably should) + # Check in order of most important states (order doesn't currently + # matter, but it probably should) # First see if any OS states are compromised for x, os_state in enumerate(os): if os_state == "COMPROMISED": @@ -26,8 +32,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): action_node_property = "OS" property_action = "PATCHING" action_service_index = 0 # does nothing isn't relevant for os - action = [action_node_id, action_node_property, - property_action, action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -44,8 +54,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): property_action = "PATCHING" action_service_index = service_num - action = [action_node_id, action_node_property, - property_action, action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = 
transform_action_node_enum(action) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -63,8 +77,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): property_action = "PATCHING" action_service_index = service_num - action = [action_node_id, action_node_property, - property_action, action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -75,10 +93,18 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): if os_state == "OFF": action_node_id = x + 1 action_node_property = "OPERATING" - property_action = "ON" # Why reset it when we can just turn it on - action_service_index = 0 # does nothing isn't relevant for operating state - action = [action_node_id, action_node_property, - property_action, action_service_index] + property_action = ( + "ON" # Why reset it when we can just turn it on + ) + action_service_index = ( + 0 # does nothing isn't relevant for operating state + ) + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action, action_dict) action = get_new_action(action, action_dict) # We can only perform 1 action on each step @@ -89,8 +115,12 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): action_node_property = "NONE" property_action = "ON" action_service_index = 0 - action = [action_node_id, action_node_property, property_action, - action_service_index] + action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] action = transform_action_node_enum(action) action = get_new_action(action, action_dict) diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 8a6428bb..35ae1b53 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -1,28 
+1,35 @@ +from __future__ import annotations + import json from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Optional, Union +import tensorflow as tf from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig from ray.rllib.algorithms.ppo import PPOConfig from ray.tune.logger import UnifiedLogger from ray.tune.registry import register_env -import tensorflow as tf + from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentFramework, AgentIdentifier, \ - DeepLearningFramework +from primaite.common.enums import ( + AgentFramework, + AgentIdentifier, + DeepLearningFramework, +) from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) + def _env_creator(env_config): return Primaite( training_config_path=env_config["training_config_path"], lay_down_config_path=env_config["lay_down_config_path"], session_path=env_config["session_path"], - timestamp_str=env_config["timestamp_str"] + timestamp_str=env_config["timestamp_str"], ) @@ -37,16 +44,15 @@ def _custom_log_creator(session_path: Path): class RLlibAgent(AgentSessionABC): + """An AgentSession class that implements a Ray RLlib agent.""" - def __init__( - self, - training_config_path, - lay_down_config_path - ): + def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: - msg = (f"Expected RLLIB agent_framework, " - f"got {self._training_config.agent_framework}") + msg = ( + f"Expected RLLIB agent_framework, " + f"got {self._training_config.agent_framework}" + ) _LOGGER.error(msg) raise ValueError(msg) if self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -54,8 +60,10 @@ class RLlibAgent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_config_class = 
A2CConfig else: - msg = ("Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier.value}") + msg = ( + "Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier.value}" + ) _LOGGER.error(msg) raise ValueError(msg) self._agent_config: PPOConfig @@ -86,8 +94,12 @@ class RLlibAgent(AgentSessionABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._current_result["episodes_total"] - metadata_dict["total_time_steps"] = self._current_result["timesteps_total"] + metadata_dict["total_episodes"] = self._current_result[ + "episodes_total" + ] + metadata_dict["total_time_steps"] = self._current_result[ + "timesteps_total" + ] filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -106,43 +118,48 @@ class RLlibAgent(AgentSessionABC): training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, session_path=self.session_path, - timestamp_str=self.timestamp_str - ) + timestamp_str=self.timestamp_str, + ), ) self._agent_config.training( train_batch_size=self._training_config.num_steps ) - self._agent_config.framework( - framework="tf" - ) + self._agent_config.framework(framework="tf") self._agent_config.rollouts( num_rollout_workers=1, num_envs_per_worker=1, - horizon=self._training_config.num_steps + horizon=self._training_config.num_steps, ) self._agent: Algorithm = self._agent_config.build( logger_creator=_custom_log_creator(self.session_path) ) - def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] if checkpoint_n > 0 and episode_count > 0: - if ( - (episode_count % checkpoint_n == 0) - or (episode_count == self._training_config.num_episodes) + if (episode_count % checkpoint_n == 0) or ( + episode_count == 
self._training_config.num_episodes ): - self._agent.save(self.checkpoints_path) + self._agent.save(str(self.checkpoints_path)) def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ # Temporarily override train_batch_size and horizon if time_steps: self._agent_config.train_batch_size = time_steps @@ -150,37 +167,53 @@ class RLlibAgent(AgentSessionABC): if not episodes: episodes = self._training_config.num_episodes - _LOGGER.info(f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps...") + _LOGGER.info( + f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps..." + ) for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() - if self._training_config.deep_learning_framework != DeepLearningFramework.TORCH: + if ( + self._training_config.deep_learning_framework + != DeepLearningFramework.TORCH + ): policy = self._agent.get_policy() tf.compat.v1.summary.FileWriter( - self.session_path / "ray_results", - policy.get_session().graph + self.session_path / "ray_results", policy.get_session().graph ) super().learn() self._agent.stop() def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. 
+ :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ raise NotImplementedError def _get_latest_checkpoint(self): raise NotImplementedError @classmethod - def load(cls): + def load(cls, path: Union[str, Path]) -> RLlibAgent: + """Load an agent from file.""" raise NotImplementedError def save(self): + """Save the agent.""" raise NotImplementedError def export(self): + """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 328e6286..8d5dd633 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,27 +1,30 @@ -from typing import Optional +from __future__ import annotations + +from pathlib import Path +from typing import Optional, Union import numpy as np -from stable_baselines3 import PPO, A2C +from stable_baselines3 import A2C, PPO from stable_baselines3.ppo import MlpPolicy as PPOMlp from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import AgentIdentifier, AgentFramework +from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) class SB3Agent(AgentSessionABC): - def __init__( - self, - training_config_path, - lay_down_config_path - ): + """An AgentSession class that implements a Stable Baselines3 agent.""" + + def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.SB3: - msg = (f"Expected SB3 agent_framework, " - f"got {self._training_config.agent_framework}") + msg = ( + f"Expected SB3 agent_framework, " + f"got {self._training_config.agent_framework}" + ) _LOGGER.error(msg) raise ValueError(msg) if 
self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -29,8 +32,10 @@ class SB3Agent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_class = A2C else: - msg = ("Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier.value}") + msg = ( + "Expected PPO or A2C agent_identifier, " + f"got {self._training_config.agent_identifier}" + ) _LOGGER.error(msg) raise ValueError(msg) @@ -52,25 +57,26 @@ class SB3Agent(AgentSessionABC): training_config_path=self._training_config_path, lay_down_config_path=self._lay_down_config_path, session_path=self.session_path, - timestamp_str=self.timestamp_str + timestamp_str=self.timestamp_str, ) self._agent = self._agent_class( PPOMlp, self._env, verbose=self.output_verbose_level, n_steps=self._training_config.num_steps, - tensorboard_log=self._tensorboard_log_path + tensorboard_log=self._tensorboard_log_path, ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count if checkpoint_n > 0 and episode_count > 0: - if ( - (episode_count % checkpoint_n == 0) - or (episode_count == self._training_config.num_episodes) + if (episode_count % checkpoint_n == 0) or ( + episode_count == self._training_config.num_episodes ): - checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + checkpoint_path = ( + self.checkpoints_path / f"sb3ppo_{episode_count}.zip" + ) self._agent.save(checkpoint_path) _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") @@ -78,33 +84,54 @@ class SB3Agent(AgentSessionABC): pass def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. 
+ :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param kwargs: Any agent-specific key-word args to be passed. + """ if not time_steps: time_steps = self._training_config.num_steps if not episodes: episodes = self._training_config.num_episodes self.is_eval = False - _LOGGER.info(f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps...") + _LOGGER.info( + f"Beginning learning for {episodes} episodes @" + f" {time_steps} time steps..." + ) for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() - self._env.close() + self.close() super().learn() def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - deterministic: bool = True, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + deterministic: bool = True, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of steps per episode. Optional. If not + passed, the value from the training config will be used. + :param episodes: The number of episodes. Optional. If not + passed, the value from the training config will be used. + :param deterministic: Whether the evaluation is deterministic. + :param kwargs: Any agent-specific key-word args to be passed. + """ if not time_steps: time_steps = self._training_config.num_steps @@ -116,27 +143,31 @@ class SB3Agent(AgentSessionABC): deterministic_str = "deterministic" else: deterministic_str = "non-deterministic" - _LOGGER.info(f"Beginning {deterministic_str} evaluation for " - f"{episodes} episodes @ {time_steps} time steps...") + _LOGGER.info( + f"Beginning {deterministic_str} evaluation for " + f"{episodes} episodes @ {time_steps} time steps..." 
+ ) for episode in range(episodes): obs = self._env.reset() for step in range(time_steps): action, _states = self._agent.predict( - obs, - deterministic=deterministic + obs, deterministic=deterministic ) if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) - _LOGGER.info(f"Finished evaluation") + super().evaluate() @classmethod - def load(self): + def load(cls, path: Union[str, Path]) -> SB3Agent: + """Load an agent from file.""" raise NotImplementedError def save(self): + """Save the agent.""" raise NotImplementedError def export(self): + """Export the agent to transportable file format.""" raise NotImplementedError diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index a4eadc3b..c3e67fdf 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -4,9 +4,9 @@ from primaite.common.enums import ( HardwareState, LinkStatus, NodeHardwareAction, + NodePOLType, NodeSoftwareAction, SoftwareState, - NodePOLType ) @@ -16,14 +16,17 @@ def transform_action_node_readable(action): example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_node_property = NodePOLType(action[1]).name if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[ - 2 - ] <= 1: + elif ( + action_node_property == "OS" or action_node_property == "SERVICE" + ) and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -38,6 +41,9 @@ def transform_action_acl_readable(action): example: [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
""" action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} action_permissions = {0: "DENY", 1: "ALLOW"} @@ -62,6 +68,9 @@ def is_valid_node_action(action): Does NOT consider: - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_node_readable(action) @@ -77,7 +86,10 @@ def is_valid_node_action(action): if node_property == "OPERATING" and node_action == "PATCHING": # Operating State cannot PATCH return False - if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]: + if node_property != "OPERATING" and node_action not in [ + "NONE", + "PATCHING", + ]: # Software States can only do Nothing or Patch return False return True @@ -92,6 +104,9 @@ def is_valid_acl_action(action): Does NOT consider: - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_acl_readable(action) @@ -118,7 +133,12 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """Harsher version of valid acl actions, does not allow action.""" + """ + Harsher version of valid acl actions, does not allow action. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ if is_valid_acl_action(action) is False: return False @@ -136,13 +156,15 @@ def is_valid_acl_action_extra(action): return True - def transform_change_obs_readable(obs): - """Transform list of transactions to readable list of each observation property + """ + Transform list of transactions to readable list of each observation property. 
example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 2], ['OFF', 'ON'], ['GOOD', 'GOOD'], ['COMPROMISED', 'GOOD']] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ ids = [i for i in obs[:, 0]] operating_states = [HardwareState(i).name for i in obs[:, 1]] @@ -151,7 +173,9 @@ def transform_change_obs_readable(obs): for service in range(3, obs.shape[1]): # Links bit/s don't have a service state - service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] + service_states = [ + SoftwareState(i).name if i <= 4 else i for i in obs[:, service] + ] new_obs.append(service_states) return new_obs @@ -159,10 +183,13 @@ def transform_change_obs_readable(obs): def transform_obs_readable(obs): """ - example: - np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] - """ + Transform observation to readable format. + np.array([[1,2,1,3],[2,1,1,1]]) -> [[1, 'OFF', 'GOOD', 'COMPROMISED'], [2, 'ON', 'GOOD', 'GOOD']] + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ changed_obs = transform_change_obs_readable(obs) new_obs = list(zip(*changed_obs)) # Convert list of tuples to list of lists @@ -172,7 +199,12 @@ def transform_obs_readable(obs): def convert_to_new_obs(obs, num_nodes=10): - """Convert original gym Box observation space to new multiDiscrete observation space""" + """ + Convert original gym Box observation space to new multiDiscrete observation space. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ # Remove ID columns, remove links and flatten to MultiDiscrete observation space new_obs = obs[:num_nodes, 1:].flatten() return new_obs @@ -180,7 +212,9 @@ def convert_to_new_obs(obs, num_nodes=10): def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): """ - Convert to old observation, links filled with 0's as no information is included in new observation space + Convert to old observation. + + Links filled with 0's as no information is included in new observation space. example: obs = array([1, 1, 1, 1, 1, 1, 1, 1, 1, ..., 1, 1, 1]) @@ -190,13 +224,17 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): [ 3, 1, 1, 1], ... [20, 0, 0, 0]]) + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ - # Convert back to more readable, original format reshaped_nodes = obs[:-num_links].reshape(num_nodes, num_services + 2) # Add empty links back and add node ID back - s = np.zeros([reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], dtype=np.int64) + s = np.zeros( + [reshaped_nodes.shape[0] + num_links, reshaped_nodes.shape[1] + 1], + dtype=np.int64, + ) s[:, 0] = range(1, num_nodes + num_links + 1) # Adding ID back s[:num_nodes, 1:] = reshaped_nodes # put values back in new_obs = s @@ -209,14 +247,19 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): - """Return string describing change between two observations +def describe_obs_change( + obs1, obs2, num_nodes=10, num_links=10, num_services=1 +): + """ + Return string describing change between two observations. example: obs_1 = array([[1, 1, 1, 1, 3], [2, 1, 1, 1, 1]]) obs_2 = array([[1, 1, 1, 1, 1], [2, 1, 1, 1, 1]]) output = 'ID 1: SERVICE 2 set to GOOD' + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
""" obs1 = convert_to_old_obs(obs1, num_nodes, num_links, num_services) obs2 = convert_to_old_obs(obs2, num_nodes, num_links, num_services) @@ -236,20 +279,27 @@ def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): def _describe_obs_change_helper(obs_change, is_link): - """ " - Helper funcion to describe what has changed + """ + Helper funcion to describe what has changed. example: [ 1 -1 -1 -1 1] -> "ID 1: Service 1 changed to GOOD" Handles multiple changes e.g. 'ID 1: SERVICE 1 changed to PATCHING. SERVICE 2 set to GOOD.' + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ # Indexes where a change has occured, not including 0th index - index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] + index_changed = [ + i for i in range(1, len(obs_change)) if obs_change[i] != -1 + ] # Node pol types, Indexes >= 3 are service nodes NodePOLTypes = [ - NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed + NodePOLType(i).name + if i < 3 + else NodePOLType(3).name + " " + str(i - 3) + for i in index_changed ] # Account for hardware states, software sattes and links states = [ @@ -263,8 +313,8 @@ def _describe_obs_change_helper(obs_change, is_link): if not is_link: desc = f"ID {obs_change[0]}:" - for NodePOLType, state in list(zip(NodePOLTypes, states)): - desc = desc + " " + NodePOLType + " changed to " + state + "." + for node_pol_type, state in list(zip(NodePOLTypes, states)): + desc = desc + " " + node_pol_type + " changed to " + state + "." else: desc = f"ID {obs_change[0]}: Link traffic changed to {states[0]}." @@ -273,12 +323,14 @@ def _describe_obs_change_helper(obs_change, is_link): def transform_action_node_enum(action): """ - Convert a node action from readable string format, to enumerated format + Convert a node action from readable string format, to enumerated format. 
example: [1, 'SERVICE', 'PATCHING', 0] -> [1, 3, 1, 0] - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_node_id = action[0] action_node_property = NodePOLType[action[1]].value @@ -291,24 +343,33 @@ def transform_action_node_enum(action): action_service_index = action[3] - new_action = [action_node_id, action_node_property, property_action, action_service_index] + new_action = [ + action_node_id, + action_node_property, + property_action, + action_service_index, + ] return new_action def transform_action_node_readable(action): """ - Convert a node action from enumerated format to readable format + Convert a node action from enumerated format to readable format. example: [1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0] - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_node_property = NodePOLType(action[1]).name if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: + elif ( + action_node_property == "OS" or action_node_property == "SERVICE" + ) and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -319,9 +380,11 @@ def transform_action_node_readable(action): def node_action_description(action): """ - Generate string describing a node-based action - """ + Generate string describing a node-based action. + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ if isinstance(action[1], (int, np.int64)): # transform action to readable format action = transform_action_node_readable(action) @@ -334,7 +397,9 @@ def node_action_description(action): if property_action == "NONE": return "" if node_property == "OPERATING" or node_property == "OS": - description = f"NODE {node_id}, {node_property}, SET TO {property_action}" + description = ( + f"NODE {node_id}, {node_property}, SET TO {property_action}" + ) elif node_property == "SERVICE": description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" else: @@ -343,34 +408,13 @@ def node_action_description(action): return description -def transform_action_acl_readable(action): - """ - Transform an ACL action to a more readable format - - example: - [0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1] - """ - - action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"} - action_permissions = {0: "DENY", 1: "ALLOW"} - - action_decision = action_decisions[action[0]] - action_permission = action_permissions[action[1]] - - # For IPs, Ports and Protocols, 0 means any, otherwise its just an index - new_action = [action_decision, action_permission] + list(action[2:6]) - for n, val in enumerate(list(action[2:6])): - if val == 0: - new_action[n + 2] = "ANY" - - return new_action - - def transform_action_acl_enum(action): """ - Convert a acl action from readable string format, to enumerated format - """ + Convert acl action from readable str format, to enumerated format. + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ action_decisions = {"NONE": 0, "CREATE": 1, "DELETE": 2} action_permissions = {"DENY": 0, "ALLOW": 1} @@ -388,8 +432,12 @@ def transform_action_acl_enum(action): def acl_action_description(action): - """generate string describing a acl-based action""" + """ + Generate string describing an acl-based action. + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ if isinstance(action[0], (int, np.int64)): # transform action to readable format action = transform_action_acl_readable(action) @@ -406,11 +454,13 @@ def acl_action_description(action): def get_node_of_ip(ip, node_dict): """ - Get the node ID of an IP address + Get the node ID of an IP address. node_dict: dictionary of nodes where key is ID, and value is the node (can be ontained from env.nodes) - """ + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ for node_key, node_value in node_dict.items(): node_ip = node_value.ip_address if node_ip == ip: @@ -418,13 +468,16 @@ def get_node_of_ip(ip, node_dict): def is_valid_node_action(action): - """Is the node action an actual valid action + """Is the node action an actual valid action. Only uses information about the action to determine if the action has an effect Does NOT consider: - Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch - Node already being in that state (turning an ON node ON) + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_node_readable(action) @@ -438,7 +491,10 @@ def is_valid_node_action(action): if node_property == "OPERATING" and node_action == "PATCHING": # Operating State cannot PATCH return False - if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]: + if node_property != "OPERATING" and node_action not in [ + "NONE", + "PATCHING", + ]: # Software States can only do Nothing or Patch return False return True @@ -446,13 +502,16 @@ def is_valid_node_action(action): def is_valid_acl_action(action): """ - Is the ACL action an actual valid action + Is the ACL action an actual valid action. 
Only uses information about the action to determine if the action has an effect Does NOT consider: - Trying to create identical rules - Trying to create a rule which is a subset of another rule (caused by "ANY") + + TODO: Add params and return in docstring. + TODO: Typehint params and return. """ action_r = transform_action_acl_readable(action) @@ -463,7 +522,11 @@ def is_valid_acl_action(action): if action_decision == "NONE": return False - if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": + if ( + action_source_id == action_destination_id + and action_source_id != "ANY" + and action_destination_id != "ANY" + ): # ACL rule towards itself return False if action_permission == "DENY": @@ -475,7 +538,12 @@ def is_valid_acl_action(action): def is_valid_acl_action_extra(action): - """Harsher version of valid acl actions, does not allow action""" + """ + Harsher version of valid acl actions, does not allow action. + + TODO: Add params and return in docstring. + TODO: Typehint params and return. + """ if is_valid_acl_action(action) is False: return False @@ -494,33 +562,17 @@ def is_valid_acl_action_extra(action): def get_new_action(old_action, action_dict): - """Get new action (e.g. 32) from old action e.g. [1,1,1,0] - - old_action can be either node or acl action type """ + Get new action (e.g. 32) from old action e.g. [1,1,1,0]. + Old_action can be either node or acl action type + + TODO: Add params and return in docstring. + TODO: Typehint params and return. 
+ """ for key, val in action_dict.items(): if list(val) == list(old_action): return key # Not all possible actions are included in dict, only valid action are # if action is not in the dict, its an invalid action so return 0 return 0 - - -def get_action_description(action, action_dict): - """ - Get a string describing/explaining what an action is doing in words - """ - - action_array = action_dict[action] - if len(action_array) == 4: - # node actions have length 4 - action_description = node_action_description(action_array) - elif len(action_array) == 6: - # acl actions have length 6 - action_description = acl_action_description(action_array) - else: - # Should never happen - action_description = "Unrecognised action" - - return action_description diff --git a/src/primaite/cli.py b/src/primaite/cli.py index aa88a391..10e23bfc 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -13,6 +13,8 @@ import yaml from platformdirs import PlatformDirs from typing_extensions import Annotated +from primaite.data_viz import PlotlyTemplate + app = typer.Typer() @@ -54,7 +56,9 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): print(re.sub(r"\n*", "", line)) -_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa +_LogLevel = Enum( + "LogLevel", {k: k for k in logging._levelToName.values()} +) # noqa @app.command() @@ -76,11 +80,12 @@ def log_level(level: Annotated[Optional[_LogLevel], typer.Argument()] = None): primaite_config = yaml.safe_load(file) if level: - primaite_config["log_level"] = level.value + primaite_config["logging"]["log_level"] = level.value with open(user_config_path, "w") as file: yaml.dump(primaite_config, file) + print(f"PrimAITE Log Level: {level}") else: - level = primaite_config["log_level"] + level = primaite_config["logging"]["log_level"] print(f"PrimAITE Log Level: {level}") @@ -170,16 +175,50 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None): ldc: The lay down config file path. Optional. 
If no value is passed then example default lay down config is used from: - ~/primaite/config/example_config/lay_down/lay_down_config_5_data_manipulation.yaml. + ~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml. """ - from primaite.main import run + from primaite.config.lay_down_config import dos_very_basic_config_path from primaite.config.training_config import main_training_config_path - from primaite.config.lay_down_config import data_manipulation_config_path + from primaite.main import run if not tc: tc = main_training_config_path() if not ldc: - ldc = data_manipulation_config_path() + ldc = dos_very_basic_config_path() run(training_config_path=tc, lay_down_config_path=ldc) + + +@app.command() +def plotly_template( + template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None +): + """ + View or set the plotly template for Session plots. + + To View, simply call: primaite plotly-template + + To set, call: primaite plotly-template + + For example, to set as plotly_dark, call: primaite plotly-template PLOTLY_DARK + """ + app_dirs = PlatformDirs(appname="primaite") + app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) + user_config_path = app_dirs.user_config_path / "primaite_config.yaml" + if user_config_path.exists(): + with open(user_config_path, "r") as file: + primaite_config = yaml.safe_load(file) + + if template: + primaite_config["session"]["outputs"]["plots"][ + "template" + ] = template.value + with open(user_config_path, "w") as file: + yaml.dump(primaite_config, file) + print(f"PrimAITE plotly template: {template.value}") + else: + template = primaite_config["session"]["outputs"]["plots"][ + "template" + ] + print(f"PrimAITE plotly template: {template}") diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index 6a93e1b5..a363a1a0 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -83,6 +83,7 @@ class Protocol(Enum): class SessionType(Enum): """The type 
of PrimAITE Session to be run.""" + TRAIN = 1 "Train an agent" EVAL = 2 @@ -93,6 +94,7 @@ class SessionType(Enum): class VerboseLevel(IntEnum): """PrimAITE Session Output verbose level.""" + NO_OUTPUT = 0 INFO = 1 DEBUG = 2 @@ -100,6 +102,7 @@ class VerboseLevel(IntEnum): class AgentFramework(Enum): """The agent algorithm framework/package.""" + CUSTOM = 0 "Custom Agent" SB3 = 1 @@ -110,6 +113,7 @@ class AgentFramework(Enum): class DeepLearningFramework(Enum): """The deep learning framework.""" + TF = "tf" "Tensorflow" TF2 = "tf2" @@ -120,6 +124,7 @@ class DeepLearningFramework(Enum): class AgentIdentifier(Enum): """The Red Agent algo/class.""" + A2C = 1 "Advantage Actor Critic" PPO = 2 @@ -136,6 +141,7 @@ class AgentIdentifier(Enum): class HardCodedAgentView(Enum): """The view the deterministic hard-coded agent has of the environment.""" + BASIC = 1 "The current observation space only" FULL = 2 @@ -144,6 +150,7 @@ class HardCodedAgentView(Enum): class ActionType(Enum): """Action type enumeration.""" + NODE = 0 ACL = 1 ANY = 2 @@ -151,6 +158,7 @@ class ActionType(Enum): class ObservationType(Enum): """Observation type enumeration.""" + BOX = 0 MULTIDISCRETE = 1 @@ -193,6 +201,7 @@ class LinkStatus(Enum): class OutputVerboseLevel(IntEnum): """The Agent output verbosity level.""" + NONE = 0 "No Output" INFO = 1 diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 3cccbcae..cc5d4955 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -35,10 +35,10 @@ hard_coded_agent_view: FULL # "NODE" # "ACL" # "ANY" node and acl actions -action_type: NODE +action_type: ANY # Number of episodes to run per session -num_episodes: 1000 +num_episodes: 10 # Number of time_steps per episode num_steps: 256 @@ -47,14 +47,14 @@ num_steps: 256 # Set to 0 if no checkpoints are 
required. Default is 10 checkpoint_every_n_episodes: 10 -# Time delay between steps (for generic agents) +# Time delay (milliseconds) between steps for CUSTOM agents. time_delay: 5 # Type of session to be run. Options are: # "TRAIN" (Trains an agent) # "EVAL" (Evaluates an agent) # "TRAIN_EVAL" (Trains then evaluates an agent) -session_type: TRAIN +session_type: TRAIN_EVAL # Environment config values # The high value for the observation space diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index 49a33d6e..ae067228 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -1,20 +1,20 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. from pathlib import Path -from typing import Final, Union, Dict, Any +from typing import Any, Dict, Final, Union -import networkx import yaml -from primaite import USERS_CONFIG_DIR, getLogger +from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) -_EXAMPLE_LAY_DOWN: Final[ - Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" +_EXAMPLE_LAY_DOWN: Final[Path] = ( + USERS_CONFIG_DIR / "example_config" / "lay_down" +) def convert_legacy_lay_down_config_dict( - legacy_config_dict: Dict[str, Any] + legacy_config_dict: Dict[str, Any] ) -> Dict[str, Any]: """ Convert a legacy lay down config dict to the new format. @@ -25,10 +25,7 @@ def convert_legacy_lay_down_config_dict( return legacy_config_dict -def load( - file_path: Union[str, Path], - legacy_file: bool = False -) -> Dict: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict: """ Read in a lay down config yaml file. 
diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 72b5523a..84dd3cc8 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -7,15 +7,22 @@ from typing import Any, Dict, Final, Optional, Union import yaml -from primaite import USERS_CONFIG_DIR, getLogger -from primaite.common.enums import DeepLearningFramework, HardCodedAgentView -from primaite.common.enums import ActionType, AgentIdentifier, \ - AgentFramework, SessionType, OutputVerboseLevel +from primaite import getLogger, USERS_CONFIG_DIR +from primaite.common.enums import ( + ActionType, + AgentFramework, + AgentIdentifier, + DeepLearningFramework, + HardCodedAgentView, + OutputVerboseLevel, + SessionType, +) _LOGGER = getLogger(__name__) -_EXAMPLE_TRAINING: Final[ - Path] = USERS_CONFIG_DIR / "example_config" / "training" +_EXAMPLE_TRAINING: Final[Path] = ( + USERS_CONFIG_DIR / "example_config" / "training" +) def main_training_config_path() -> Path: @@ -36,6 +43,7 @@ def main_training_config_path() -> Path: @dataclass() class TrainingConfig: """The Training Config class.""" + agent_framework: AgentFramework = AgentFramework.SB3 "The AgentFramework" @@ -171,12 +179,16 @@ class TrainingConfig: file_system_scanning_limit: int = 5 "The time taken to scan the file system" - @classmethod def from_dict( - cls, - config_dict: Dict[str, Union[str, int, bool]] + cls, config_dict: Dict[str, Union[str, int, bool]] ) -> TrainingConfig: + """ + Create an instance of TrainingConfig from a dict. + + :param config_dict: The training config dict. + :return: The instance of TrainingConfig. 
+ """ field_enum_map = { "agent_framework": AgentFramework, "deep_learning_framework": DeepLearningFramework, @@ -187,9 +199,9 @@ class TrainingConfig: "hard_coded_agent_view": HardCodedAgentView, } - for field, enum_class in field_enum_map.items(): - if field in config_dict: - config_dict[field] = enum_class[config_dict[field]] + for key, value in field_enum_map.items(): + if key in config_dict: + config_dict[key] = value[config_dict[key]] return TrainingConfig(**config_dict) def to_dict(self, json_serializable: bool = True): @@ -213,23 +225,21 @@ class TrainingConfig: return data def __str__(self) -> str: - tc = f"TrainingConfig(agent_framework={self.agent_framework.name}, " + tc = f"{self.agent_framework}, " if self.agent_framework is AgentFramework.RLLIB: - tc += f"deep_learning_framework=" \ - f"{self.deep_learning_framework.name}, " - tc += f"agent_identifier={self.agent_identifier.name}, " + tc += f"{self.deep_learning_framework}, " + tc += f"{self.agent_identifier}, " if self.agent_identifier is AgentIdentifier.HARDCODED: - tc += f"hard_coded_agent_view={self.hard_coded_agent_view.name}, " - tc += f"action_type={self.action_type.name}, " + tc += f"{self.hard_coded_agent_view}, " + tc += f"{self.action_type}, " tc += f"observation_space={self.observation_space}, " - tc += f"num_episodes={self.num_episodes}, " - tc += f"num_steps={self.num_steps})" + tc += f"{self.num_episodes} episodes @ " + tc += f"{self.num_steps} steps" return tc def load( - file_path: Union[str, Path], - legacy_file: bool = False + file_path: Union[str, Path], legacy_file: bool = False ) -> TrainingConfig: """ Read in a training config yaml file. 
@@ -273,12 +283,12 @@ def load( def convert_legacy_training_config_dict( - legacy_config_dict: Dict[str, Any], - agent_framework: AgentFramework = AgentFramework.SB3, - agent_identifier: AgentIdentifier = AgentIdentifier.PPO, - action_type: ActionType = ActionType.ANY, - num_steps: int = 256, - output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO + legacy_config_dict: Dict[str, Any], + agent_framework: AgentFramework = AgentFramework.SB3, + agent_identifier: AgentIdentifier = AgentIdentifier.PPO, + action_type: ActionType = ActionType.ANY, + num_steps: int = 256, + output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO, ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -301,8 +311,12 @@ def convert_legacy_training_config_dict( "agent_identifier": agent_identifier.name, "action_type": action_type.name, "num_steps": num_steps, - "output_verbose_level": output_verbose_level + "output_verbose_level": output_verbose_level.name, } + session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} + legacy_config_dict["sessionType"] = session_type_map[ + legacy_config_dict["sessionType"] + ] for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/data_viz/__init__.py b/src/primaite/data_viz/__init__.py new file mode 100644 index 00000000..a7cc3e8b --- /dev/null +++ b/src/primaite/data_viz/__init__.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class PlotlyTemplate(Enum): + """The built-in plotly templates.""" + + PLOTLY = "plotly" + PLOTLY_WHITE = "plotly_white" + PLOTLY_DARK = "plotly_dark" + GGPLOT2 = "ggplot2" + SEABORN = "seaborn" + SIMPLE_WHITE = "simple_white" + NONE = "none" diff --git a/src/primaite/data_viz/session_plots.py b/src/primaite/data_viz/session_plots.py new file mode 100644 index 00000000..245b9774 --- /dev/null +++ b/src/primaite/data_viz/session_plots.py @@ -0,0 +1,73 @@ +from pathlib import Path +from typing 
import Dict, Optional, Union + +import plotly.graph_objects as go +import polars as pl +import yaml +from plotly.graph_objs import Figure + +from primaite import _PLATFORM_DIRS + + +def _get_plotly_config() -> Dict: + """Get the plotly config from primaite_config.yaml.""" + user_config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" + with open(user_config_path, "r") as file: + primaite_config = yaml.safe_load(file) + return primaite_config["session"]["outputs"]["plots"] + + +def plot_av_reward_per_episode( + av_reward_per_episode_csv: Union[str, Path], + title: Optional[str] = None, + subtitle: Optional[str] = None, +) -> Figure: + """ + Plot the average reward per episode from a csv session output. + + :param av_reward_per_episode_csv: The average reward per episode csv + file path. + :param title: The plot title. This is optional. + :param subtitle: The plot subtitle. This is optional. + :return: The plot as an instance of ``plotly.graph_objs._figure.Figure``. + """ + df = pl.read_csv(av_reward_per_episode_csv) + + if title: + if subtitle: + title = f"{title}
{subtitle}" + else: + if subtitle: + title = subtitle + + config = _get_plotly_config() + layout = go.Layout( + autosize=config["size"]["auto_size"], + width=config["size"]["width"], + height=config["size"]["height"], + ) + # Create the line graph with a colored line + fig = go.Figure(layout=layout) + fig.update_layout(template=config["template"]) + fig.add_trace( + go.Scatter( + x=df["Episode"], + y=df["Average Reward"], + mode="lines", + name="Mean Reward per Episode", + ) + ) + + # Set the layout of the graph + fig.update_layout( + xaxis={ + "title": "Episode", + "type": "linear", + "rangeslider": {"visible": config["range_slider"]}, + }, + yaxis={"title": "Average Reward"}, + title=title, + showlegend=False, + ) + + return fig diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 9e71ef1b..6893125e 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,7 +1,7 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union +from typing import Dict, Final, List, Tuple, TYPE_CHECKING, Union import numpy as np from gym import spaces @@ -77,7 +77,9 @@ class NodeLinkTable(AbstractObservationComponent): ) # 3. Initialise Observation with zeroes - self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) + self.current_observation = np.zeros( + observation_shape, dtype=self._DATA_TYPE + ) def update(self): """Update the observation based on current environment state. 
@@ -92,7 +94,9 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation[item_index][0] = int(node.node_id) self.current_observation[item_index][1] = node.hardware_state.value if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.current_observation[item_index][2] = node.software_state.value + self.current_observation[item_index][ + 2 + ] = node.software_state.value self.current_observation[item_index][ 3 ] = node.file_system_state_observed.value @@ -199,9 +203,16 @@ class NodeStatuses(AbstractObservationComponent): if isinstance(node, ServiceNode): for i, service in enumerate(self.env.services_list): if node.has_service(service): - service_states[i] = node.get_service_state(service).value + service_states[i] = node.get_service_state( + service + ).value obs.extend( - [hardware_state, software_state, file_system_state, *service_states] + [ + hardware_state, + software_state, + file_system_state, + *service_states, + ] ) self.current_observation[:] = obs @@ -259,7 +270,9 @@ class LinkTrafficLevels(AbstractObservationComponent): # 1. Define the shape of your observation space component shape = ( - [self._quantisation_levels] * self.env.num_links * self._entries_per_link + [self._quantisation_levels] + * self.env.num_links + * self._entries_per_link ) # 2. 
Create Observation space @@ -279,7 +292,9 @@ class LinkTrafficLevels(AbstractObservationComponent): if self._combine_service_traffic: loads = [link.get_current_load()] else: - loads = [protocol.get_load() for protocol in link.protocol_list] + loads = [ + protocol.get_load() for protocol in link.protocol_list + ] for load in loads: if load <= 0: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e43dc8a5..ea8f82d4 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -2,7 +2,7 @@ """Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy from pathlib import Path -from typing import Dict, Tuple, Union, Final +from typing import Dict, Final, Tuple, Union import networkx as nx import numpy as np @@ -12,8 +12,10 @@ from matplotlib import pyplot as plt from primaite import getLogger from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import is_valid_acl_action_extra, \ - is_valid_node_action +from primaite.agents.utils import ( + is_valid_acl_action_extra, + is_valid_node_action, +) from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, @@ -24,7 +26,8 @@ from primaite.common.enums import ( NodeType, ObservationType, Priority, - SoftwareState, SessionType, + SessionType, + SoftwareState, ) from primaite.common.service import Service from primaite.config import training_config @@ -34,15 +37,18 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import \ - NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import ( + NodeStateInstructionGreen, +) from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from 
primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import apply_red_agent_iers, \ - apply_red_agent_node_pol +from primaite.pol.red_agent_pol import ( + apply_red_agent_iers, + apply_red_agent_node_pol, +) from primaite.transactions.transaction import Transaction from primaite.utils.session_output_writer import SessionOutputWriter @@ -59,11 +65,11 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. @@ -237,27 +243,19 @@ class Primaite(Env): ) self.episode_av_reward_writer = SessionOutputWriter( - self, - transaction_writer=False, - learning_session=True + self, transaction_writer=False, learning_session=True ) self.transaction_writer = SessionOutputWriter( - self, - transaction_writer=True, - learning_session=True + self, transaction_writer=True, learning_session=True ) def set_as_eval(self): """Set the writers to write to eval directories.""" self.episode_av_reward_writer = SessionOutputWriter( - self, - transaction_writer=False, - learning_session=False + self, transaction_writer=False, learning_session=False ) self.transaction_writer = SessionOutputWriter( - self, - transaction_writer=True, - learning_session=False + self, transaction_writer=True, learning_session=False ) self.episode_count = 0 self.step_count = 0 @@ -322,9 +320,7 @@ class Primaite(Env): # Create a Transaction (metric) object for this step transaction = Transaction( - self.agent_identifier, - self.episode_count, - self.step_count + self.agent_identifier, self.episode_count, self.step_count ) # 
Load the initial observation space into the transaction transaction.obs_space_pre = copy.deepcopy(self.env_obs) @@ -354,8 +350,9 @@ class Primaite(Env): self.nodes_post_pol = copy.deepcopy(self.nodes) self.links_post_pol = copy.deepcopy(self.links) # Reference - apply_node_pol(self.nodes_reference, self.node_pol, - self.step_count) # Node PoL + apply_node_pol( + self.nodes_reference, self.node_pol, self.step_count + ) # Node PoL apply_iers( self.network_reference, self.nodes_reference, @@ -404,8 +401,10 @@ class Primaite(Env): # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - _LOGGER.info(f"Episode: {self.episode_count}, " - f"Average Reward: {self.average_reward}") + _LOGGER.info( + f"Episode: {self.episode_count}, " + f"Average Reward: {self.average_reward}" + ) # Load the reward into the transaction transaction.reward = reward @@ -452,11 +451,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -525,7 +524,7 @@ class Primaite(Env): # Patch (valid action if it's good or compromised) node.set_service_state( self.services_list[service_index], - SoftwareState.PATCHING + SoftwareState.PATCHING, ) else: # Node is not of Service Type @@ -542,7 +541,10 @@ class Primaite(Env): elif property_action == 2: # Repair # You cannot repair a destroyed file system - it needs restoring - if node.file_system_state_actual != FileSystemState.DESTROYED: + if ( + node.file_system_state_actual + != FileSystemState.DESTROYED + ): node.set_file_system_state(FileSystemState.REPAIRING) elif 
property_action == 3: # Restore @@ -585,8 +587,9 @@ class Primaite(Env): acl_rule_source = "ANY" else: node = list(self.nodes.values())[action_source_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, - ActiveNode): + if isinstance(node, ServiceNode) or isinstance( + node, ActiveNode + ): acl_rule_source = node.ip_address else: return @@ -595,8 +598,9 @@ class Primaite(Env): acl_rule_destination = "ANY" else: node = list(self.nodes.values())[action_destination_ip - 1] - if isinstance(node, ServiceNode) or isinstance(node, - ActiveNode): + if isinstance(node, ServiceNode) or isinstance( + node, ActiveNode + ): acl_rule_destination = node.ip_address else: return @@ -681,8 +685,9 @@ class Primaite(Env): :return: The observation space, initial observation (zeroed out array with the correct shape) :rtype: Tuple[spaces.Space, np.ndarray] """ - self.obs_handler = ObservationsHandler.from_config(self, - self.obs_config) + self.obs_handler = ObservationsHandler.from_config( + self, self.obs_config + ) return self.obs_handler.space, self.obs_handler.current_observation @@ -790,7 +795,8 @@ class Primaite(Env): service_port = service["port"] service_state = SoftwareState[service["state"]] node.add_service( - Service(service_protocol, service_port, service_state)) + Service(service_protocol, service_port, service_state) + ) else: # Bad formatting pass @@ -843,8 +849,9 @@ class Primaite(Env): dest_node_ref: Node = self.nodes_reference[link_destination] # Add link to network (reference) - self.network_reference.add_edge(source_node_ref, dest_node_ref, - id=link_name) + self.network_reference.add_edge( + source_node_ref, dest_node_ref, id=link_name + ) # Add link to link dictionary (reference) self.links_reference[link_name] = Link( @@ -1120,7 +1127,8 @@ class Primaite(Env): node_id = item["node_id"] node_class = item["node_class"] node_hardware_state: HardwareState = HardwareState[ - item["hardware_state"]] + item["hardware_state"] + ] node: NodeUnion = 
self.nodes[node_id] node_ref = self.nodes_reference[node_id] @@ -1186,8 +1194,12 @@ class Primaite(Env): # Use MAX to ensure we get them all for node_action in range(4): for service_state in range(self.num_services): - action = [node, node_property, node_action, - service_state] + action = [ + node, + node_property, + node_action, + service_state, + ] # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): actions[action_key] = action diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 00e45fa3..4dd0550e 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -46,7 +46,9 @@ def calculate_reward_function( ) # Software State - if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode): + if isinstance(final_node, ActiveNode) or isinstance( + final_node, ServiceNode + ): reward_value += score_node_os_state( final_node, initial_node, reference_node, config_values ) @@ -81,7 +83,8 @@ def calculate_reward_function( reference_blocked = not reference_ier.get_is_running() live_blocked = not ier_value.get_is_running() ier_reward = ( - config_values.green_ier_blocked * ier_value.get_mission_criticality() + config_values.green_ier_blocked + * ier_value.get_mission_criticality() ) if live_blocked and not reference_blocked: @@ -104,7 +107,9 @@ def calculate_reward_function( return reward_value -def score_node_operating_state(final_node, initial_node, reference_node, config_values): +def score_node_operating_state( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the hardware state of a node. 
@@ -153,7 +158,9 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_ return score -def score_node_os_state(final_node, initial_node, reference_node, config_values): +def score_node_os_state( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the Software State of a node. @@ -204,7 +211,9 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values) return score -def score_node_service_state(final_node, initial_node, reference_node, config_values): +def score_node_service_state( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the service state(s) of a node. @@ -276,7 +285,9 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va return score -def score_node_file_system(final_node, initial_node, reference_node, config_values): +def score_node_file_system( + final_node, initial_node, reference_node, config_values +): """ Calculates score relating to the file system state of a node. diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index 90235e9f..054f4c34 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -8,7 +8,9 @@ from primaite.common.protocol import Protocol class Link(object): """Link class.""" - def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): + def __init__( + self, _id, _bandwidth, _source_node_name, _dest_node_name, _services + ): """ Init. diff --git a/src/primaite/main.py b/src/primaite/main.py index 3c0f93b3..556c5ec3 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -10,7 +10,10 @@ from primaite.primaite_session import PrimaiteSession _LOGGER = getLogger(__name__) -def run(training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]): +def run( + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], +): """Run the PrimAITE Session. 
:param training_config_path: The training config filepath. diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index 57fa4c68..b1c3f57c 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -87,7 +87,9 @@ class ActiveNode(Node): f"Node.software_state:{self._software_state}" ) - def set_software_state_if_not_compromised(self, software_state: SoftwareState): + def set_software_state_if_not_compromised( + self, software_state: SoftwareState + ): """ Sets Software State if the node is not compromised. @@ -98,7 +100,9 @@ class ActiveNode(Node): if self._software_state != SoftwareState.COMPROMISED: self._software_state = software_state if software_state == SoftwareState.PATCHING: - self.patching_count = self.config_values.os_patching_duration + self.patching_count = ( + self.config_values.os_patching_duration + ) else: _LOGGER.info( f"The Nodes hardware state is OFF so OS State cannot be changed." @@ -187,7 +191,9 @@ class ActiveNode(Node): def start_file_system_scan(self): """Starts a file system scan.""" self.file_system_scanning = True - self.file_system_scanning_count = self.config_values.file_system_scanning_limit + self.file_system_scanning_count = ( + self.config_values.file_system_scanning_limit + ) def update_file_system_state(self): """Updates file system status based on scanning/restore/repair cycle.""" @@ -206,7 +212,10 @@ class ActiveNode(Node): self.file_system_state_observed = FileSystemState.GOOD # Scanning updates - if self.file_system_scanning == True and self.file_system_scanning_count < 0: + if ( + self.file_system_scanning == True + and self.file_system_scanning_count < 0 + ): self.file_system_state_observed = self.file_system_state_actual self.file_system_scanning = False self.file_system_scanning_count = 0 diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 2b1d94be..04681807 100644 --- 
a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -32,7 +32,9 @@ class NodeStateInstructionGreen(object): self.end_step = _end_step self.node_id = _node_id self.node_pol_type = _node_pol_type - self.service_name = _service_name # Not used when not a service instruction + self.service_name = ( + _service_name # Not used when not a service instruction + ) self.state = _state def get_start_step(self): diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 7f62fe24..ba35067c 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -42,7 +42,9 @@ class NodeStateInstructionRed(object): self.target_node_id = _target_node_id self.initiator = _pol_initiator self.pol_type: NodePOLType = _pol_type - self.service_name = pol_protocol # Not used when not a service instruction + self.service_name = ( + pol_protocol # Not used when not a service instruction + ) self.state = _pol_state self.source_node_id = _pol_source_node_id self.source_node_service = _pol_source_node_service diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 324592c3..6dcff73e 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -110,7 +110,9 @@ class ServiceNode(ActiveNode): return False return False - def set_service_state(self, protocol_name: str, software_state: SoftwareState): + def set_service_state( + self, protocol_name: str, software_state: SoftwareState + ): """ Sets the software_state of a service (protocol) on the node. 
diff --git a/src/primaite/notebooks/__init__.py b/src/primaite/notebooks/__init__.py index 71ed343e..0e81e581 100644 --- a/src/primaite/notebooks/__init__.py +++ b/src/primaite/notebooks/__init__.py @@ -4,7 +4,7 @@ import os import subprocess import sys -from primaite import NOTEBOOKS_DIR, getLogger +from primaite import getLogger, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index 1d05dc3f..aeae7add 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -6,10 +6,17 @@ from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState +from primaite.common.enums import ( + HardwareState, + NodePOLType, + NodeType, + SoftwareState, +) from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode -from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen +from primaite.nodes.node_state_instruction_green import ( + NodeStateInstructionGreen, +) from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER @@ -190,7 +197,9 @@ def apply_iers( link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if (link.get_current_load() + load) > link.get_bandwidth(): + if ( + link.get_current_load() + load + ) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -204,7 +213,8 @@ def apply_iers( while count < path_node_list_length - 1: # Get the link between the next two nodes edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] + path_node_list[count], + path_node_list[count + 1], ) link_id = edge_dict[0].get("id") link = 
links[link_id] @@ -216,7 +226,9 @@ def apply_iers( else: # One of the nodes is not operational if _VERBOSE: - print("Path not valid - one or more nodes not operational") + print( + "Path not valid - one or more nodes not operational" + ) pass else: @@ -231,7 +243,9 @@ def apply_iers( def apply_node_pol( nodes: Dict[str, NodeUnion], - node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], + node_pol: Dict[ + any, Union[NodeStateInstructionGreen, NodeStateInstructionRed] + ], step: int, ): """ @@ -263,16 +277,22 @@ def apply_node_pol( elif node_pol_type == NodePOLType.OS: # Change OS state # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + if isinstance(node, ActiveNode) or isinstance( + node, ServiceNode + ): node.set_software_state_if_not_compromised(state) elif node_pol_type == NodePOLType.SERVICE: # Change a service state # Don't allow PoL to fix something that is compromised. 
Only the Blue agent can do this if isinstance(node, ServiceNode): - node.set_service_state_if_not_compromised(service_name, state) + node.set_service_state_if_not_compromised( + service_name, state + ) else: # Change the file system status - if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): + if isinstance(node, ActiveNode) or isinstance( + node, ServiceNode + ): node.set_file_system_state_if_not_compromised(state) else: # PoL is not valid in this time step diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index b23992e7..96fe787c 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -176,7 +176,9 @@ def apply_red_agent_iers( link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if (link.get_current_load() + load) > link.get_bandwidth(): + if ( + link.get_current_load() + load + ) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -190,7 +192,8 @@ def apply_red_agent_iers( while count < path_node_list_length - 1: # Get the link between the next two nodes edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] + path_node_list[count], + path_node_list[count + 1], ) link_id = edge_dict[0].get("id") link = links[link_id] @@ -200,16 +203,23 @@ def apply_red_agent_iers( # This IER is now valid, so set it to running ier_value.set_is_running(True) if _VERBOSE: - print("Red IER was allowed to run in step " + str(step)) + print( + "Red IER was allowed to run in step " + + str(step) + ) else: # One of the nodes is not operational if _VERBOSE: - print("Path not valid - one or more nodes not operational") + print( + "Path not valid - one or more nodes not operational" + ) pass else: if _VERBOSE: - print("Red IER was NOT allowed to run in step " + str(step)) + print( + "Red IER was NOT allowed to run in step " + str(step) + ) print("Source, Dest or ACL were not 
valid") pass # ------------------------------------ @@ -264,7 +274,9 @@ def apply_red_agent_node_pol( passed_checks = True elif initiator == NodePOLInitiator.IER: # Need to check there is a red IER incoming - passed_checks = is_red_ier_incoming(target_node, iers, pol_type) + passed_checks = is_red_ier_incoming( + target_node, iers, pol_type + ) elif initiator == NodePOLInitiator.SERVICE: # Need to check the condition of a service on another node source_node = nodes[source_node_id] @@ -308,7 +320,9 @@ def apply_red_agent_node_pol( target_node.set_file_system_state(state) else: if _VERBOSE: - print("Node Red Agent PoL not allowed - did not pass checks") + print( + "Node Red Agent PoL not allowed - did not pass checks" + ) else: # PoL is not valid in this time step pass @@ -323,7 +337,10 @@ def is_red_ier_incoming(node, iers, node_pol_type): node_id = node.node_id for ier_key, ier_value in iers.items(): - if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id: + if ( + ier_value.get_is_running() + and ier_value.get_dest_node_id() == node_id + ): if ( node_pol_type == NodePOLType.OPERATING or node_pol_type == NodePOLType.OS diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index cd959be0..4d8d3022 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -1,54 +1,51 @@ from __future__ import annotations -import json -from datetime import datetime from pathlib import Path -from typing import Final, Optional, Union, Dict -from uuid import uuid4 +from typing import Dict, Final, Optional, Union -from primaite import getLogger, SESSIONS_DIR +from primaite import getLogger from primaite.agents.agent import AgentSessionABC from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent -from primaite.agents.simple import DoNothingACLAgent, 
DoNothingNodeAgent, \ - RandomAgent, DummyAgent -from primaite.common.enums import AgentFramework, AgentIdentifier, \ - ActionType, SessionType +from primaite.agents.simple import ( + DoNothingACLAgent, + DoNothingNodeAgent, + DummyAgent, + RandomAgent, +) +from primaite.common.enums import ( + ActionType, + AgentFramework, + AgentIdentifier, + SessionType, +) from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig -from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) -def _get_session_path(session_timestamp: datetime) -> Path: - """ - Get the directory path the session will output to. - - This is set in the format of: - ~/primaite/sessions//_. - - :param session_timestamp: This is the datetime that the session started. - :return: The session directory path. - """ - date_dir = session_timestamp.strftime("%Y-%m-%d") - session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = SESSIONS_DIR / date_dir / session_path - session_path.mkdir(exist_ok=True, parents=True) - _LOGGER.debug(f"Created PrimAITE Session path: {session_path}") - - return session_path - - class PrimaiteSession: + """ + The PrimaiteSession class. + + Provides a single learning and evaluation entry point for all training + and lay down configurations. + """ def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path] + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], ): + """ + The PrimaiteSession constructor. + + :param training_config_path: The training config path. + :param lay_down_config_path: The lay down config path. 
+ """ if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path @@ -64,22 +61,35 @@ class PrimaiteSession: ) self._agent_session: AgentSessionABC = None # noqa + self.session_path: Path = None # noqa + self.timestamp_str: str = None # noqa + self.learning_path: Path = None # noqa + self.evaluation_path: Path = None # noqa def setup(self): + """Performs the session setup.""" if self._training_config.agent_framework == AgentFramework.CUSTOM: - if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}" + ) + if ( + self._training_config.agent_identifier + == AgentIdentifier.HARDCODED + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.HARDCODED}" + ) if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space self._agent_session = HardCodedNodeAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space self._agent_session = HardCodedACLAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ANY: @@ -90,18 +100,23 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: + elif ( + self._training_config.agent_identifier + == AgentIdentifier.DO_NOTHING + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.DO_NOTHING}" + ) if self._training_config.action_type == ActionType.NODE: self._agent_session = 
self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space self._agent_session = DoNothingACLAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.action_type == ActionType.ANY: @@ -112,15 +127,26 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: - self._agent_session = RandomAgent( - self._training_config_path, - self._lay_down_config_path + elif ( + self._training_config.agent_identifier + == AgentIdentifier.RANDOM + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.RANDOM}" + ) + self._agent_session = RandomAgent( + self._training_config_path, self._lay_down_config_path + ) + elif ( + self._training_config.agent_identifier == AgentIdentifier.DUMMY + ): + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Identifier =" + f" {AgentIdentifier.DUMMY}" ) - elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: self._agent_session = DummyAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) else: @@ -128,37 +154,64 @@ class PrimaiteSession: raise ValueError elif self._training_config.agent_framework == AgentFramework.SB3: + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}" + ) # Stable Baselines3 Agent self._agent_session = SB3Agent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) elif self._training_config.agent_framework == AgentFramework.RLLIB: + _LOGGER.debug( + f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}" + ) # Ray RLlib Agent self._agent_session = 
RLlibAgent( - self._training_config_path, - self._lay_down_config_path + self._training_config_path, self._lay_down_config_path ) else: # Invalid AgentFramework raise ValueError + self.session_path: Path = self._agent_session.session_path + self.timestamp_str: str = self._agent_session.timestamp_str + self.learning_path: Path = self._agent_session.learning_path + self.evaluation_path: Path = self._agent_session.evaluation_path + def learn( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Train the agent. + + :param time_steps: The number of time steps per episode. + :param episodes: The number of episodes. + :param kwargs: Any agent-framework specific key word args. + """ if not self._training_config.session_type == SessionType.EVAL: self._agent_session.learn(time_steps, episodes, **kwargs) def evaluate( - self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, - **kwargs + self, + time_steps: Optional[int] = None, + episodes: Optional[int] = None, + **kwargs, ): + """ + Evaluate the agent. + + :param time_steps: The number of time steps per episode. + :param episodes: The number of episodes. + :param kwargs: Any agent-framework specific key word args. 
+ """ if not self._training_config.session_type == SessionType.TRAIN: self._agent_session.evaluate(time_steps, episodes, **kwargs) + + def close(self): + """Closes the agent.""" + self._agent_session.close() diff --git a/src/primaite/setup/_package_data/primaite_config.yaml b/src/primaite/setup/_package_data/primaite_config.yaml index 1dd8775b..b9e0d73c 100644 --- a/src/primaite/setup/_package_data/primaite_config.yaml +++ b/src/primaite/setup/_package_data/primaite_config.yaml @@ -9,3 +9,14 @@ logging: WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s' + +# Session +session: + outputs: + plots: + size: + auto_size: false + width: 1500 + height: 900 + template: plotly_white + range_slider: false diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 5192c48f..59eaf8cc 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -6,7 +6,7 @@ from pathlib import Path import pkg_resources -from primaite import NOTEBOOKS_DIR, getLogger +from primaite import getLogger, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) @@ -24,7 +24,9 @@ def run(overwrite_existing: bool = True): for subdir, dirs, files in os.walk(notebooks_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep) + path_split = os.path.relpath( + fp, notebooks_package_data_root + ).split(os.sep) target_fp = NOTEBOOKS_DIR / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index f4166c6a..f2b4a18f 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ 
-5,7 +5,7 @@ from pathlib import Path import pkg_resources -from primaite import USERS_CONFIG_DIR, getLogger +from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) @@ -24,7 +24,9 @@ def run(overwrite_existing=True): for subdir, dirs, files in os.walk(configs_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep) + path_split = os.path.relpath(fp, configs_package_data_root).split( + os.sep + ) target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/setup/setup_app_dirs.py b/src/primaite/setup/setup_app_dirs.py index 9f6e8a13..693b11c1 100644 --- a/src/primaite/setup/setup_app_dirs.py +++ b/src/primaite/setup/setup_app_dirs.py @@ -1,5 +1,5 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -from primaite import _USER_DIRS, LOG_DIR, NOTEBOOKS_DIR, getLogger +from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR _LOGGER = getLogger(__name__) diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 6e5ba5f0..1a71f0ff 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -7,12 +7,7 @@ from typing import List, Tuple class Transaction(object): """Transaction class.""" - def __init__( - self, - agent_identifier, - episode_number, - step_number - ): + def __init__(self, agent_identifier, episode_number, step_number): """ Transaction constructor. @@ -37,6 +32,11 @@ class Transaction(object): "The action space invoked by the agent" def as_csv_data(self) -> Tuple[List, List]: + """ + Converts the Transaction to a csv data row and provides a header. + + :return: A tuple consisting of (header, data). 
+ """ if isinstance(self.action_space, int): action_length = self.action_space else: @@ -74,12 +74,14 @@ class Transaction(object): str(self.reward), ] row = ( - row - + _turn_action_space_to_array(self.action_space) - + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, - obs_features) - + _turn_obs_space_to_array(self.obs_space_post, obs_assets, - obs_features) + row + + _turn_action_space_to_array(self.action_space) + + _turn_obs_space_to_array( + self.obs_space_pre, obs_assets, obs_features + ) + + _turn_obs_space_to_array( + self.obs_space_post, obs_assets, obs_features + ) ) return header, row diff --git a/src/primaite/utils/session_output_reader.py b/src/primaite/utils/session_output_reader.py new file mode 100644 index 00000000..d04f375e --- /dev/null +++ b/src/primaite/utils/session_output_reader.py @@ -0,0 +1,20 @@ +from pathlib import Path +from typing import Dict, Union + +# Using polars as it's faster than Pandas; it will speed things up when +# files get big! +import polars as pl + + +def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]: + """ + Read an average rewards per episode csv file and return as a dict. + + The dictionary keys are the episode number, and the values are the mean + reward for that episode. + + :param av_rewards_csv_file: The average rewards per episode csv file path. + :return: The average rewards per episode csv as a dict. 
+ """ + d = pl.read_csv(av_rewards_csv_file).to_dict() + return {v: d["Average Reward"][i] for i, v in enumerate(d["Episode"])} diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index 308e1fb3..86c5ca28 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -1,7 +1,6 @@ import csv from logging import Logger -from typing import List, Final, IO, Union, Tuple -from typing import TYPE_CHECKING +from typing import Final, List, Tuple, TYPE_CHECKING, Union from primaite import getLogger from primaite.transactions.transaction import Transaction @@ -13,15 +12,22 @@ _LOGGER: Logger = getLogger(__name__) class SessionOutputWriter: + """ + A session output writer class. + + Is used to write session outputs to a csv file. + """ + _AV_REWARD_PER_EPISODE_HEADER: Final[List[str]] = [ - "Episode", "Average Reward" + "Episode", + "Average Reward", ] def __init__( - self, - env: "Primaite", - transaction_writer: bool = False, - learning_session: bool = True + self, + env: "Primaite", + transaction_writer: bool = False, + learning_session: bool = True, ): self._env = env self.transaction_writer = transaction_writer @@ -52,14 +58,21 @@ class SessionOutputWriter: self._csv_writer = csv.writer(self._csv_file) def __del__(self): + self.close() + + def close(self): + """Close the csv file.""" if self._csv_file: self._csv_file.close() - _LOGGER.info(f"Finished writing file: {self._csv_file_path}") + _LOGGER.debug(f"Finished writing file: {self._csv_file_path}") - def write( - self, - data: Union[Tuple, Transaction] - ): + def write(self, data: Union[Tuple, Transaction]): + """ + Write a row of session data. + + :param data: The row of data to write. Can be a Tuple or an instance + of Transaction. 
+ """ if isinstance(data, Transaction): header, data = data.as_csv_data() else: @@ -69,5 +82,4 @@ class SessionOutputWriter: self._init_csv_writer() self._csv_writer.writerow(header) self._first_write = False - self._csv_writer.writerow(data) diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy/new_training_config.yaml index 9fdf9a05..49e6a00b 100644 --- a/tests/config/legacy/new_training_config.yaml +++ b/tests/config/legacy/new_training_config.yaml @@ -6,7 +6,7 @@ # "SB3" (Stable Baselines3) # "RLLIB" (Ray[RLlib]) # "NONE" (Custom Agent) -agent_framework: RLLIB +agent_framework: SB3 # Sets which Red Agent algo/class will be used: # "PPO" (Proximal Policy Optimization) @@ -27,7 +27,7 @@ num_steps: 256 # Time delay between steps (for generic agents) time_delay: 10 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml index 67aaa9de..d26d7955 100644 --- a/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml +++ b/tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: SB3 + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: A2C -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -28,7 +39,7 @@ observation_space: time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml index 29a89b8d..aae740b6 100644 --- a/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml +++ b/tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -24,7 +35,7 @@ observation_space: time_delay: 1 # Filename of the scenario / laydown -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml index 8f2d9a38..4066eace 100644 --- a/tests/config/obs_tests/main_config_NODE_STATUSES.yaml +++ b/tests/config/obs_tests/main_config_NODE_STATUSES.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -25,7 +36,7 @@ observation_space: time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/obs_tests/main_config_without_obs.yaml b/tests/config/obs_tests/main_config_without_obs.yaml index e8bb49ea..08452dda 100644 --- a/tests/config/obs_tests/main_config_without_obs.yaml +++ b/tests/config/obs_tests/main_config_without_obs.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: NONE # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 5 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: TRAIN # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/one_node_states_on_off_main_config.yaml b/tests/config/one_node_states_on_off_main_config.yaml index 2e752bc9..7f1ced01 100644 --- a/tests/config/one_node_states_on_off_main_config.yaml +++ b/tests/config/one_node_states_on_off_main_config.yaml @@ -1,10 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: DUMMY -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +30,7 @@ num_steps: 15 time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml index 5c5db582..97d0ddaf 100644 --- a/tests/config/single_action_space_fixed_blue_actions_main_config.yaml +++ b/tests/config/single_action_space_fixed_blue_actions_main_config.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 15 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/config/single_action_space_main_config.yaml b/tests/config/single_action_space_main_config.yaml index 967fdcce..067b9a6d 100644 --- a/tests/config/single_action_space_main_config.yaml +++ b/tests/config/single_action_space_main_config.yaml @@ -1,11 +1,22 @@ -# Main Config File +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: RANDOM -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -18,7 +29,7 @@ num_steps: 5 # Time delay between steps (for generic agents) time_delay: 1 # Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING +session_type: EVAL # Determine whether to load an agent from file load_agent: False # File path and file name of agent if you're loading one in diff --git a/tests/conftest.py b/tests/conftest.py index 945d23f0..41dc5e77 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,37 +1,151 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
+import datetime +import shutil import tempfile import time from datetime import datetime from pathlib import Path -from typing import Union +from typing import Dict, Union +from unittest.mock import patch +import pytest + +from primaite import getLogger +from primaite.common.enums import AgentIdentifier from primaite.environment.primaite_env import Primaite +from primaite.primaite_session import PrimaiteSession +from primaite.utils.session_output_reader import av_rewards_dict +from tests.mock_and_patch.get_session_path_mock import get_temp_session_path ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1 +_LOGGER = getLogger(__name__) -def _get_temp_session_path(session_timestamp: datetime) -> Path: + +class TempPrimaiteSession(PrimaiteSession): + """ + A temporary PrimaiteSession class. + + Uses context manager for deletion of files upon exit. + """ + + def __init__( + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + ): + super().__init__(training_config_path, lay_down_config_path) + self.setup() + + def learn_av_reward_per_episode(self) -> Dict[int, float]: + """Get the learn av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.learning_path / csv_file) + + def eval_av_reward_per_episode_csv(self) -> Dict[int, float]: + """Get the eval av reward per episode from file.""" + csv_file = f"average_reward_per_episode_{self.timestamp_str}.csv" + return av_rewards_dict(self.evaluation_path / csv_file) + + @property + def env(self) -> Primaite: + """Direct access to the env for ease of testing.""" + return self._agent_session._env # noqa + + def __enter__(self): + return self + + def __exit__(self, type, value, tb): + del self._agent_session._env.episode_av_reward_writer + del self._agent_session._env.transaction_writer + shutil.rmtree(self.session_path) + shutil.rmtree(self.session_path.parent) + _LOGGER.debug(f"Deleted temp session 
directory: {self.session_path}") + + +@pytest.fixture +def temp_primaite_session(request): + """ + Provides a temporary PrimaiteSession instance. + + It's temporary as it uses a temporary directory as the session path. + + To use this fixture you need to: + + - parametrize your test function with: + + - "temp_primaite_session" + - [[path to training config, path to lay down config]] + - Include the temp_primaite_session fixture as a param in your test + function. + - use the temp_primaite_session as a context manager assigning it the + name 'session'. + + .. code:: python + + from primaite.config.lay_down_config import dos_very_basic_config_path + from primaite.config.training_config import main_training_config_path + @pytest.mark.parametrize( + "temp_primaite_session", + [ + [main_training_config_path(), dos_very_basic_config_path()] + ], + indirect=True + ) + def test_primaite_session(temp_primaite_session): + with temp_primaite_session as session: + # Learning outputs are saved in session.learning_path + session.learn() + + # Evaluation outputs are saved in session.evaluation_path + session.evaluate() + + # To ensure that all files are written, you must call .close() + session.close() + + # If you need to inspect any session outputs, it must be done + # inside the context manager + + # Now that we've exited the context manager, the + # session.session_path directory and its contents are deleted + """ + training_config_path = request.param[0] + lay_down_config_path = request.param[1] + with patch( + "primaite.agents.agent.get_session_path", get_temp_session_path + ) as mck: + mck.session_timestamp = datetime.now() + + return TempPrimaiteSession(training_config_path, lay_down_config_path) + + +@pytest.fixture +def temp_session_path() -> Path: """ Get a temp directory session path the test session will output to. - :param session_timestamp: This is the datetime that the session started. :return: The session directory path. 
""" + session_timestamp = datetime.now() date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + session_path = ( + Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + ) session_path.mkdir(exist_ok=True, parents=True) return session_path def _get_primaite_env_from_config( - training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path] + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + temp_session_path, ): """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() - session_path = _get_temp_session_path(session_timestamp) + session_path = temp_session_path(session_timestamp) timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( @@ -45,7 +159,7 @@ def _get_primaite_env_from_config( # TOOD: This needs t be refactored to happen outside. Should be part of # a main Session class. 
- if env.training_config.agent_identifier == "GENERIC": + if env.training_config.agent_identifier is AgentIdentifier.RANDOM: run_generic(env, config_values) return env diff --git a/tests/e2e_integration_tests/test_primaite_main.py b/tests/e2e_integration_tests/test_primaite_main.py deleted file mode 100644 index b457557a..00000000 --- a/tests/e2e_integration_tests/test_primaite_main.py +++ /dev/null @@ -1,8 +0,0 @@ -from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.config.training_config import main_training_config_path -from primaite.main import run - - -def test_primaite_main_e2e(): - """Tests the primaite.main.run function end-to-end.""" - run(main_training_config_path(), data_manipulation_config_path()) diff --git a/tests/e2e_integration_tests/__init__.py b/tests/mock_and_patch/__init__.py similarity index 100% rename from tests/e2e_integration_tests/__init__.py rename to tests/mock_and_patch/__init__.py diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py new file mode 100644 index 00000000..cfcfb8f0 --- /dev/null +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -0,0 +1,24 @@ +import tempfile +from datetime import datetime +from pathlib import Path + +from primaite import getLogger + +_LOGGER = getLogger(__name__) + + +def get_temp_session_path(session_timestamp: datetime) -> Path: + """ + Get a temp directory session path the test session will output to. + + :param session_timestamp: This is the datetime that the session started. + :return: The session directory path. 
+ """ + date_dir = session_timestamp.strftime("%Y-%m-%d") + session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") + session_path = ( + Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path + ) + session_path.mkdir(exist_ok=True, parents=True) + _LOGGER.debug(f"Created temp session directory: {session_path}") + return session_path diff --git a/tests/test_active_node.py b/tests/test_active_node.py index addc595c..b6833182 100644 --- a/tests/test_active_node.py +++ b/tests/test_active_node.py @@ -60,7 +60,9 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state): 1, ) - active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED) + active_node.set_software_state_if_not_compromised( + SoftwareState.OVERWHELMED + ) assert active_node.software_state == expected_state @@ -98,7 +100,9 @@ def test_file_system_change(operating_state, expected_state): (HardwareState.ON, FileSystemState.CORRUPT), ], ) -def test_file_system_change_if_not_compromised(operating_state, expected_state): +def test_file_system_change_if_not_compromised( + operating_state, expected_state +): """ Test that a node cannot change its file system state. 
@@ -116,6 +120,8 @@ def test_file_system_change_if_not_compromised(operating_state, expected_state): 1, ) - active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT) + active_node.set_file_system_state_if_not_compromised( + FileSystemState.CORRUPT + ) assert active_node.file_system_state_actual == expected_state diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index efca7b0b..21e4857f 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -7,79 +7,78 @@ from primaite.environment.observations import ( NodeStatuses, ObservationsHandler, ) -from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config -@pytest.fixture -def env(request): - """Build Primaite environment for integration tests of observation space.""" - marker = request.node.get_closest_marker("env_config_paths") - training_config_path = marker.args[0]["training_config_path"] - lay_down_config_path = marker.args[0]["lay_down_config_path"] - env = _get_primaite_env_from_config( - training_config_path=training_config_path, - lay_down_config_path=lay_down_config_path, - ) - yield env - - -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_without_obs.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) -def test_default_obs_space(env: Primaite): +def test_default_obs_space(temp_primaite_session): """Create environment with no obs space defined in config and check that the default obs space was created.""" - env.update_environent_obs() + with temp_primaite_session as session: + session.env.update_environent_obs() - components = env.obs_handler.registered_obs_components 
+ components = session.env.obs_handler.registered_obs_components - assert len(components) == 1 - assert isinstance(components[0], NodeLinkTable) + assert len(components) == 1 + assert isinstance(components[0], NodeLinkTable) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_without_obs.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) -def test_registering_components(env: Primaite): +def test_registering_components(temp_primaite_session): """Test regitering and deregistering a component.""" - handler = ObservationsHandler() - component = NodeStatuses(env) - handler.register(component) - assert component in handler.registered_obs_components - handler.deregister(component) - assert component not in handler.registered_obs_components + with temp_primaite_session as session: + env = session.env + handler = ObservationsHandler() + component = NodeStatuses(env) + handler.register(component) + assert component in handler.registered_obs_components + handler.deregister(component) + assert component not in handler.registered_obs_components -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_NODE_LINK_TABLE.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_LINK_TABLE.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestNodeLinkTable: """Test the NodeLinkTable observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with box observation space.""" - 
env.update_environent_obs() + with temp_primaite_session as session: + env = session.env + env.update_environent_obs() - # we have three nodes and two links, with two service - # therefore the box observation space will have: - # * 5 rows (3 nodes + 2 links) - # * 6 columns (four fixed and two for the services) - assert env.env_obs.shape == (5, 6) + # we have three nodes and two links, with two service + # therefore the box observation space will have: + # * 5 rows (3 nodes + 2 links) + # * 6 columns (four fixed and two for the services) + assert env.env_obs.shape == (5, 6) - def test_value(self, env: Primaite): + def test_value(self, temp_primaite_session): """Test that the observation is generated correctly. The laydown has: @@ -125,36 +124,45 @@ class TestNodeLinkTable: * 999 (999 traffic service1) * 0 (no traffic for service2) """ - # act = np.asarray([0,]) - obs, reward, done, info = env.step(0) # apply the 'do nothing' action + with temp_primaite_session as session: + env = session.env + # act = np.asarray([0,]) + obs, reward, done, info = env.step( + 0 + ) # apply the 'do nothing' action - assert np.array_equal( - obs, - [ - [1, 1, 3, 1, 1, 1], - [2, 1, 1, 1, 1, 4], - [3, 1, 1, 1, 0, 0], - [4, 0, 0, 0, 999, 0], - [5, 0, 0, 0, 999, 0], - ], - ) + assert np.array_equal( + obs, + [ + [1, 1, 3, 1, 1, 1], + [2, 1, 1, 1, 1, 4], + [3, 1, 1, 1, 0, 0], + [4, 0, 0, 0, 999, 0], + [5, 0, 0, 0, 999, 0], + ], + ) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_NODE_STATUSES.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestNodeStatuses: """Test the NodeStatuses observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, 
temp_primaite_session): """Try creating env with NodeStatuses as the only component.""" - assert env.env_obs.shape == (15,) + with temp_primaite_session as session: + env = session.env + assert env.env_obs.shape == (15,) - def test_values(self, env: Primaite): + def test_values(self, temp_primaite_session): """Test that the hardware and software states are encoded correctly. The laydown has: @@ -181,28 +189,38 @@ class TestNodeStatuses: * service 1 = n/a (0) * service 2 = n/a (0) """ - obs, _, _, _ = env.step(0) # apply the 'do nothing' action - assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) + with temp_primaite_session as session: + env = session.env + obs, _, _, _ = env.step(0) # apply the 'do nothing' action + assert np.array_equal( + obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0] + ) -@pytest.mark.env_config_paths( - dict( - training_config_path=TEST_CONFIG_ROOT - / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT + / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ] + ], + indirect=True, ) class TestLinkTrafficLevels: """Test the LinkTrafficLevels observation component (in isolation).""" - def test_obs_shape(self, env: Primaite): + def test_obs_shape(self, temp_primaite_session): """Try creating env with MultiDiscrete observation space.""" - env.update_environent_obs() + with temp_primaite_session as session: + env = session.env + env.update_environent_obs() - # we have two links and two services, so the shape should be 2 * 2 - assert env.env_obs.shape == (2 * 2,) + # we have two links and two services, so the shape should be 2 * 2 + assert env.env_obs.shape == (2 * 2,) - def test_values(self, env: Primaite): + def test_values(self, temp_primaite_session): """Test that traffic values are encoded correctly. 
The laydown has: @@ -212,12 +230,14 @@ class TestLinkTrafficLevels: * an IER trying to send 999 bits of data over both links the whole time (via the first service) * link bandwidth of 1000, therefore the utilisation is 99.9% """ - obs, reward, done, info = env.step(0) - obs, reward, done, info = env.step(0) + with temp_primaite_session as session: + env = session.env + obs, reward, done, info = env.step(0) + obs, reward, done, info = env.step(0) - # the observation space has combine_service_traffic set to False, so the space has this format: - # [link1_service1, link1_service2, link2_service1, link2_service2] - # we send 999 bits of data via link1 and link2 on service 1. - # therefore the first and third elements should be 6 and all others 0 - # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) - assert np.array_equal(obs, [6, 0, 6, 0]) + # the observation space has combine_service_traffic set to False, so the space has this format: + # [link1_service1, link1_service2, link2_service1, link2_service2] + # we send 999 bits of data via link1 and link2 on service 1. 
+ # therefore the first and third elements should be 6 and all others 0 + # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) + assert np.array_equal(obs, [6, 0, 6, 0]) diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py new file mode 100644 index 00000000..8c8d2b80 --- /dev/null +++ b/tests/test_primaite_session.py @@ -0,0 +1,61 @@ +import os + +import pytest + +from primaite import getLogger +from primaite.config.lay_down_config import dos_very_basic_config_path +from primaite.config.training_config import main_training_config_path + +_LOGGER = getLogger(__name__) + + +@pytest.mark.parametrize( + "temp_primaite_session", + [[main_training_config_path(), dos_very_basic_config_path()]], + indirect=True, +) +def test_primaite_session(temp_primaite_session): + """Tests the PrimaiteSession class and its outputs.""" + with temp_primaite_session as session: + session_path = session.session_path + assert session_path.exists() + session.learn() + # Learning outputs are saved in session.learning_path + session.evaluate() + # Evaluation outputs are saved in session.evaluation_path + + # If you need to inspect any session outputs, it must be done inside + # the context manager + + # Check that the metadata json file exists + assert (session_path / "session_metadata.json").exists() + + # Check that the network png file exists + assert (session_path / f"network_{session.timestamp_str}.png").exists() + + # Check that both the transactions and av reward csv files exist + for file in session.learning_path.iterdir(): + if file.suffix == ".csv": + assert ( + "all_transactions" in file.name + or "average_reward_per_episode" in file.name + ) + + # Check that both the transactions and av reward csv files exist + for file in session.evaluation_path.iterdir(): + if file.suffix == ".csv": + assert ( + "all_transactions" in file.name + or "average_reward_per_episode" in file.name + ) + + _LOGGER.debug("Inspecting files in temp session 
path...") + for dir_path, dir_names, file_names in os.walk(session_path): + for file in file_names: + path = os.path.join(dir_path, file) + file_str = path.split(str(session_path))[-1] + _LOGGER.debug(f" {file_str}") + + # Now that we've exited the context manager, the session.session_path + # directory and its contents are deleted + assert not session_path.exists() diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index abe8115c..e7312777 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -18,7 +18,9 @@ from primaite.nodes.service_node import ServiceNode "starting_operating_state, expected_operating_state", [(HardwareState.RESETTING, HardwareState.ON)], ) -def test_node_resets_correctly(starting_operating_state, expected_operating_state): +def test_node_resets_correctly( + starting_operating_state, expected_operating_state +): """Tests that a node resets correctly.""" active_node = ActiveNode( node_id="0", diff --git a/tests/test_reward.py b/tests/test_reward.py index c3fcdfc4..95603b54 100644 --- a/tests/test_reward.py +++ b/tests/test_reward.py @@ -1,26 +1,33 @@ +import pytest + from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config -def test_rewards_are_being_penalised_at_each_step_function(): +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_rewards_are_being_penalised_at_each_step_function( + temp_primaite_session, +): """ Test that hardware state is penalised at each step. When the initial state is OFF compared to reference state which is ON. 
- """ - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_lay_down_config.yaml", - ) - """ - On different steps (of the 13 in total) these are the following rewards for config_6 which are activated: + On different steps (of the 13 in total) these are the following rewards + for config_6 which are activated: File System State: goodShouldBeCorrupt = 5 (between Steps 1 & 3) Hardware State: onShouldBeOff = -2 (between Steps 4 & 6) Service State: goodShouldBeCompromised = 5 (between Steps 7 & 9) - Software State (Software State): goodShouldBeCompromised = 5 (between Steps 10 & 12) + Software State (Software State): goodShouldBeCompromised = 5 (between + Steps 10 & 12) Total Reward: -2 - 2 + 5 + 5 + 5 + 5 + 5 + 5 = 26 Step Count: 13 @@ -28,5 +35,8 @@ def test_rewards_are_being_penalised_at_each_step_function(): For the 4 steps where this occurs the average reward is: Average Reward: 2 (26 / 13) """ - print("average reward", env.average_reward) - assert env.average_reward == -8.0 + with temp_primaite_session as session: + session.evaluate() + session.close() + ev_rewards = session.eval_av_reward_per_episode_csv() + assert ev_rewards[1] == -8.0 diff --git a/tests/test_service_node.py b/tests/test_service_node.py index 4383fc1b..9e760b23 100644 --- a/tests/test_service_node.py +++ b/tests/test_service_node.py @@ -45,7 +45,9 @@ def test_service_state_change(operating_state, expected_state): (HardwareState.ON, SoftwareState.OVERWHELMED), ], ) -def test_service_state_change_if_not_comprised(operating_state, expected_state): +def test_service_state_change_if_not_comprised( + operating_state, expected_state +): """ Test that a node cannot change the state of a running service. 
@@ -65,6 +67,8 @@ def test_service_state_change_if_not_comprised(operating_state, expected_state): service = Service("TCP", 80, SoftwareState.GOOD) service_node.add_service(service) - service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED) + service_node.set_service_state_if_not_compromised( + "TCP", SoftwareState.OVERWHELMED + ) assert service_node.get_service_state("TCP") == expected_state diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 8ff43fe6..1cf63cde 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -1,9 +1,10 @@ import time +import pytest + from primaite.common.enums import HardwareState from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_primaite_env_from_config def run_generic_set_actions(env: Primaite): @@ -44,59 +45,72 @@ def run_generic_set_actions(env: Primaite): # env.close() -def test_single_action_space_is_valid(): - """Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations.""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / "single_action_space_lay_down_config.yaml", - ) +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "single_action_space_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_single_action_space_is_valid(temp_primaite_session): + """Test single action space is valid.""" + with temp_primaite_session as session: + env = session.env - run_generic_set_actions(env) - - # Retrieve the action space dictionary values from environment - env_action_space_dict = env.action_dict.values() - # Flags to check the conditions of the action space - contains_acl_actions = False - contains_node_actions = 
False - both_action_spaces = False - # Loop through each element of the list (which is every value from the dictionary) - for dict_item in env_action_space_dict: - # Node action detected - if len(dict_item) == 4: - contains_node_actions = True - # Link action detected - elif len(dict_item) == 6: - contains_acl_actions = True - # If both are there then the ANY action type is working - if contains_node_actions and contains_acl_actions: - both_action_spaces = True - # Check condition should be True - assert both_action_spaces + run_generic_set_actions(env) + # Retrieve the action space dictionary values from environment + env_action_space_dict = env.action_dict.values() + # Flags to check the conditions of the action space + contains_acl_actions = False + contains_node_actions = False + both_action_spaces = False + # Loop through each element of the list (which is every value from the dictionary) + for dict_item in env_action_space_dict: + # Node action detected + if len(dict_item) == 4: + contains_node_actions = True + # Link action detected + elif len(dict_item) == 6: + contains_acl_actions = True + # If both are there then the ANY action type is working + if contains_node_actions and contains_acl_actions: + both_action_spaces = True + # Check condition should be True + assert both_action_spaces -def test_agent_is_executing_actions_from_both_spaces(): +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT + / "single_action_space_fixed_blue_actions_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", + ] + ], + indirect=True, +) +def test_agent_is_executing_actions_from_both_spaces(temp_primaite_session): """Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL).""" - env = _get_primaite_env_from_config( - training_config_path=TEST_CONFIG_ROOT - / "single_action_space_fixed_blue_actions_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT - / 
"single_action_space_lay_down_config.yaml", - ) - # Run environment with specified fixed blue agent actions only - run_generic_set_actions(env) - # Retrieve hardware state of computer_1 node in laydown config - # Agent turned this off in Step 5 - computer_node_hardware_state = env.nodes["1"].hardware_state - # Retrieve the Access Control List object stored by the environment at the end of the episode - access_control_list = env.acl - # Use the Access Control List object acl object attribute to get dictionary - # Use dictionary.values() to get total list of all items in the dictionary - acl_rules_list = access_control_list.acl.values() - # Length of this list tells you how many items are in the dictionary - # This number is the frequency of Access Control Rules in the environment - # In the scenario, we specified that the agent should create only 1 acl rule - num_of_rules = len(acl_rules_list) - # Therefore these statements below MUST be true - assert computer_node_hardware_state == HardwareState.OFF - assert num_of_rules == 1 + with temp_primaite_session as session: + env = session.env + # Run environment with specified fixed blue agent actions only + run_generic_set_actions(env) + # Retrieve hardware state of computer_1 node in laydown config + # Agent turned this off in Step 5 + computer_node_hardware_state = env.nodes["1"].hardware_state + # Retrieve the Access Control List object stored by the environment at the end of the episode + access_control_list = env.acl + # Use the Access Control List object acl object attribute to get dictionary + # Use dictionary.values() to get total list of all items in the dictionary + acl_rules_list = access_control_list.acl.values() + # Length of this list tells you how many items are in the dictionary + # This number is the frequency of Access Control Rules in the environment + # In the scenario, we specified that the agent should create only 1 acl rule + num_of_rules = len(acl_rules_list) + # Therefore these statements below 
MUST be true + assert computer_node_hardware_state == HardwareState.OFF + assert num_of_rules == 1 diff --git a/tests/test_training_config.py b/tests/test_training_config.py index 02e90d30..88bc802b 100644 --- a/tests/test_training_config.py +++ b/tests/test_training_config.py @@ -16,7 +16,9 @@ def test_legacy_lay_down_config_yaml_conversion(): with open(new_path, "r") as file: new_dict = yaml.safe_load(file) - converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict) + converted_dict = training_config.convert_legacy_training_config_dict( + legacy_dict + ) for key, value in new_dict.items(): assert converted_dict[key] == value From c3c45125448c905ddc7fffa4921b47187f3bea76 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 09:54:34 +0100 Subject: [PATCH 19/43] Remove temporary file --- scratch.py | 6 ------ 1 file changed, 6 deletions(-) delete mode 100644 scratch.py diff --git a/scratch.py b/scratch.py deleted file mode 100644 index 6bab60c1..00000000 --- a/scratch.py +++ /dev/null @@ -1,6 +0,0 @@ -from primaite.main import run - -run( - "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/training/training_config_main.yaml", - "/home/cade/repos/PrimAITE/src/primaite/config/_package_data/lay_down/lay_down_config_5_data_manipulation.yaml", -) From d5402cdce8c2b9001869839e78e8372df1f2324e Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 30 Jun 2023 10:24:59 +0100 Subject: [PATCH 20/43] #917 - Added tensorflow to main deps for RLlib. - Dropped support for Python 3.11 due to not supported on Ray RLlib. - Made release pipeline only run once as we're now no longer using pure path wheels. 
--- .azure/artifact-release-pipeline.yaml | 32 ++----------------- .azure/azure-ci-build-pipeline.yaml | 2 -- pyproject.toml | 7 ++-- .../training/training_config_main.yaml | 2 +- 4 files changed, 6 insertions(+), 37 deletions(-) diff --git a/.azure/artifact-release-pipeline.yaml b/.azure/artifact-release-pipeline.yaml index ca8f5b60..47e9aacc 100644 --- a/.azure/artifact-release-pipeline.yaml +++ b/.azure/artifact-release-pipeline.yaml @@ -1,38 +1,12 @@ trigger: - main +pool: + vmImage: ubuntu-latest strategy: matrix: - Ubuntu2004Python38: - python.version: '3.8' - imageName: 'ubuntu-20.04' - Ubuntu2004Python39: - python.version: '3.9' - imageName: 'ubuntu-20.04' - Ubuntu2004Python310: + Python310: python.version: '3.10' - imageName: 'ubuntu-20.04' - WindowsPython38: - python.version: '3.8' - imageName: 'windows-latest' - WindowsPython39: - python.version: '3.9' - imageName: 'windows-latest' - WindowsPython310: - python.version: '3.10' - imageName: 'windows-latest' - MacPython38: - python.version: '3.8' - imageName: 'macOS-latest' - MacPython39: - python.version: '3.9' - imageName: 'macOS-latest' - MacPython310: - python.version: '3.10' - imageName: 'macOS-latest' - -pool: - vmImage: $(imageName) steps: - task: UsePythonVersion@0 diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 691f71e9..4c15daf5 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -16,8 +16,6 @@ strategy: python.version: '3.9' Python310: python.version: '3.10' - Python311: - python.version: '3.11' steps: - task: UsePythonVersion@0 diff --git a/pyproject.toml b/pyproject.toml index 09b60777..41ed6516 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "primaite" description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme." 
authors = [{name="QinetiQ Training and Simulation Ltd"}] license = {text = "GFX"} -requires-python = ">=3.8" +requires-python = ">=3.8, <3.11" dynamic = ["version", "readme"] classifiers = [ "License :: GFX", @@ -20,7 +20,6 @@ classifiers = [ "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3 :: Only", ] @@ -36,6 +35,7 @@ dependencies = [ "PyYAML==6.0", "ray[rllib]==2.2.0", "stable-baselines3==1.6.2", + "tensorflow==2.12.0", "typer[all]==0.9.0" ] @@ -65,9 +65,6 @@ dev = [ "sphinx-copybutton==0.5.2", "wheel==0.38.4" ] -tensorflow = [ - "tensorflow==2.12.0", -] [project.scripts] primaite = "primaite.cli:app" diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index cc5d4955..57793058 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -13,7 +13,7 @@ agent_framework: SB3 # "TF" (Tensorflow) # TF2 (Tensorflow 2.X) # TORCH (PyTorch) -deep_learning_framework: TORCH +deep_learning_framework: TF2 # Sets which Agent class will be used. 
# Options are: From 3e691b4f4611309e17049d7e7f7f41b72fc312e6 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Fri, 30 Jun 2023 10:37:23 +0100 Subject: [PATCH 21/43] #1522: remove numpy randomisation + added random red agent config --- .../training/training_config_main.yaml | 5 + .../training_config_random_red_agent.yaml | 99 +++++++++++++++++++ src/primaite/environment/primaite_env.py | 39 +++----- 3 files changed, 118 insertions(+), 25 deletions(-) create mode 100644 src/primaite/config/_package_data/training/training_config_random_red_agent.yaml diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index d01f51f3..3fe668e2 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -6,6 +6,11 @@ # "STABLE_BASELINES3_A2C" # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "NONE" + # Sets How the Action Space is defined: # "NODE" # "ACL" diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml new file mode 100644 index 00000000..9382a2b5 --- /dev/null +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -0,0 +1,99 @@ +# Main Config File + +# Generic config values +# Choose one of these (dependent on Agent being trained) +# "STABLE_BASELINES3_PPO" +# "STABLE_BASELINES3_A2C" +# "GENERIC" +agent_identifier: STABLE_BASELINES3_A2C + +# RED AGENT IDENTIFIER +# RANDOM or NONE +red_agent_identifier: "RANDOM" + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 10 +# Number of time_steps per episode +num_steps: 256 +# Time delay between steps 
(for generic agents) +time_delay: 10 +# Type of session to be run (TRAINING or EVALUATION) +session_type: TRAINING +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations 
+os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 9ac3d8e6..e592e21f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -5,6 +5,7 @@ import csv import logging from datetime import datetime from pathlib import Path +from random import randint, choice, uniform, sample from typing import Dict, Tuple, Union import networkx as nx @@ -276,8 +277,8 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - # if self.training_config.red_agent_identifier == "RANDOM": - # self.create_random_red_agent() + if self.training_config.red_agent_identifier == "RANDOM": + self.create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -1249,13 +1250,13 @@ class Primaite(Env): computers ) # only computers can become compromised # random select between 1 and max_num_nodes_compromised - num_nodes_to_compromise = np.random.randint(1, max_num_nodes_compromised) + num_nodes_to_compromise = randint(1, max_num_nodes_compromised) # Decide which of the nodes to compromise - nodes_to_be_compromised = np.random.choice(computers, num_nodes_to_compromise) + nodes_to_be_compromised = sample(computers, num_nodes_to_compromise) # choose a random compromise node to be source of attacks - source_node = np.random.choice(nodes_to_be_compromised, 1)[0] + source_node = choice(nodes_to_be_compromised) # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( @@ -1270,14 +1271,14 @@ 
class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = np.random.randint( + _start_step = randint( 2, max_step_compromised + 1 ) # step compromised - pol_service_name = np.random.choice( + pol_service_name = choice( list(node.services.keys()) ) - source_node_service = np.random.choice( + source_node_service = choice( list(source_node.services.values()) ) @@ -1301,13 +1302,13 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = np.random.randint( + ier_start_step = randint( _start_step + 2, int(self.episode_steps * 0.8) ) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = np.random.uniform(low=0.4, high=0.8) * np.random.choice( + ier_load = uniform(0.4, 0.8) * choice( bandwidths ) ier_protocol = pol_service_name # Same protocol as compromised node @@ -1328,7 +1329,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.get_ip_address(), + node.ip_address, server.ip_address, ier_service, ier_port, @@ -1337,7 +1338,7 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: # If still none found choose from all servers possible_ier_destinations = [server.node_id for server in servers] - ier_dest = np.random.choice(possible_ier_destinations) + ier_dest = choice(possible_ier_destinations) self.red_iers[ier_id] = IER( ier_id, ier_start_step, @@ -1354,22 +1355,10 @@ class Primaite(Env): overwhelm_pol.id = str(uuid.uuid4()) overwhelm_pol.end_step = self.episode_steps - # 3: Make sure the targetted node can be set to overwhelmed - with node pol # # TODO remove duplicate red pol for same targetted service - must take into account start step - # + o_pol_id = str(uuid.uuid4()) - # o_pol_start_step = ier_start_step # Can become compromised the same step attack is launched - # o_pol_end_step = ( - 
# self.episode_steps - # ) # Can become compromised at any timestep after start - # o_pol_node_id = ier_dest # Node effected is the one targetted by the IER - # o_pol_node_type = NodePOLType["SERVICE"] # Always targets service nodes - # o_pol_service_name = ( - # ier_protocol # Same protocol/service as the IER uses to attack - # ) - # o_pol_new_state = SoftwareState["OVERWHELMED"] - # o_pol_entry_node = False # Assumes servers are not entry nodes o_red_pol = NodeStateInstructionRed( _id=o_pol_id, _start_step=ier_start_step, From 2a8d28cba68190f3ec528812adaaa09318395f69 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 10:41:56 +0100 Subject: [PATCH 22/43] Remove redundant cols from transactions --- src/primaite/environment/observations.py | 2 +- src/primaite/environment/primaite_env.py | 4 +--- src/primaite/transactions/transaction.py | 13 ++----------- src/primaite/transactions/transactions_to_file.py | 9 ++++----- 4 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 023c5f30..fcd52559 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -168,7 +168,7 @@ class NodeLinkTable(AbstractObservationComponent): f"link_{link_id}_n/a", ] for j, serv in enumerate(self.env.services_list): - link_labels.append(f"node_{node_id}_service_{serv}_load") + link_labels.append(f"link_{link_id}_service_{serv}_load") structure.extend(link_labels) return structure diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e56abf9d..2418cac0 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -318,7 +318,7 @@ class Primaite(Env): datetime.now(), self.agent_identifier, self.episode_count, self.step_count ) # Load the initial observation space into the transaction - transaction.set_obs_space_pre(self.obs_handler._flat_observation) 
+ transaction.set_obs_space(self.obs_handler._flat_observation) # Load the action space into the transaction transaction.set_action_space(copy.deepcopy(action)) @@ -400,8 +400,6 @@ class Primaite(Env): # 7. Update env_obs self.update_environent_obs() - # Load the new observation space into the transaction - transaction.set_obs_space_post(self.obs_handler._flat_observation) # 8. Add the transaction to the list of transactions self.transaction_list.append(copy.deepcopy(transaction)) diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index a4ce48e3..39236217 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -20,23 +20,14 @@ class Transaction(object): self.episode_number = _episode_number self.step_number = _step_number - def set_obs_space_pre(self, _obs_space_pre): + def set_obs_space(self, _obs_space): """ Sets the observation space (pre). Args: _obs_space_pre: The observation space before any actions are taken """ - self.obs_space_pre = _obs_space_pre - - def set_obs_space_post(self, _obs_space_post): - """ - Sets the observation space (post). 
- - Args: - _obs_space_post: The observation space after any actions are taken - """ - self.obs_space_post = _obs_space_post + self.obs_space = _obs_space def set_reward(self, _reward): """ diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py index b2a4d40d..4e364f0b 100644 --- a/src/primaite/transactions/transactions_to_file.py +++ b/src/primaite/transactions/transactions_to_file.py @@ -58,12 +58,12 @@ def write_transaction_to_file( action_header.append("AS_" + str(x)) # Create the observation space headers array - obs_header_initial = [f"pre_{o}" for o in obs_space_description] - obs_header_new = [f"post_{o}" for o in obs_space_description] + # obs_header_initial = [f"pre_{o}" for o in obs_space_description] + # obs_header_new = [f"post_{o}" for o in obs_space_description] # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + obs_header_new + header = header + action_header + obs_space_description try: filename = session_path / f"all_transactions_{timestamp_str}.csv" @@ -82,8 +82,7 @@ def write_transaction_to_file( csv_data = ( csv_data + turn_action_space_to_array(transaction.action_space) - + transaction.obs_space_pre.tolist() - + transaction.obs_space_post.tolist() + + transaction.obs_space.tolist() ) csv_writer.writerow(csv_data) From 32d5889b11e405a0bd5e4ed72eb9727f37aa4652 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 10:44:04 +0100 Subject: [PATCH 23/43] Update docs --- docs/source/primaite_session.rst | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 4f639f11..a59b2361 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -78,10 +78,9 @@ PrimAITE automatically creates two sets of results from each session: * Timestamp * Episode number * Step number - * Initial observation 
space (before red and blue agent actions have been taken). Individual elements of the observation space are presented in the format OSI_X_Y - * Resulting observation space (after the red and blue agent actions have been taken) Individual elements of the observation space are presented in the format OSN_X_Y + * Initial observation space (what the blue agent observed when it decided its action) * Reward value - * Action space (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X + * Action taken (as presented by the blue agent on this step). Individual elements of the action space are presented in the format AS_X **Diagrams** From 203cc98494787805dfe6cafd9f14843363f5be58 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 30 Jun 2023 11:40:26 +0100 Subject: [PATCH 24/43] #917 - Fixed primaite_config.yaml issue in cli.py - Added kaleido to deps in pyproject.toml --- pyproject.toml | 1 + src/primaite/cli.py | 15 ++++++--------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 41ed6516..b63e83fc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ classifiers = [ dependencies = [ "gym==0.21.0", "jupyterlab==3.6.1", + "kaleido==0.2.1", "matplotlib==3.7.1", "networkx==3.1", "numpy==1.23.5", diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 10e23bfc..0431174f 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -124,15 +124,13 @@ def setup(overwrite_existing: bool = True): app_dirs = PlatformDirs(appname="primaite") app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) user_config_path = app_dirs.user_config_path / "primaite_config.yaml" - build_config = overwrite_existing or (not user_config_path.exists()) - if build_config: - pkg_config_path = Path( - pkg_resources.resource_filename( - "primaite", "setup/_package_data/primaite_config.yaml" - ) + pkg_config_path = Path( + pkg_resources.resource_filename( + 
"primaite", "setup/_package_data/primaite_config.yaml" ) + ) - shutil.copy2(pkg_config_path, user_config_path) + shutil.copy2(pkg_config_path, user_config_path) from primaite import getLogger from primaite.setup import ( @@ -146,8 +144,7 @@ def setup(overwrite_existing: bool = True): _LOGGER.info("Performing the PrimAITE first-time setup...") - if build_config: - _LOGGER.info("Building primaite_config.yaml...") + _LOGGER.info("Building primaite_config.yaml...") _LOGGER.info("Building the PrimAITE app directories...") setup_app_dirs.run() From 975ebd6de2d43aa0ad65d1d69cf45c37c81aa609 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 13:16:30 +0100 Subject: [PATCH 25/43] revert unnecessary changes. --- .../_package_data/training/training_config_main.yaml | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index a679400c..ac63c667 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -5,7 +5,7 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agent_identifier: STABLE_BASELINES3_A2C +agent_identifier: STABLE_BASELINES3_PPO # Sets How the Action Space is defined: # "NODE" # "ACL" @@ -16,12 +16,14 @@ observation_space: # flatten: true components: - name: NODE_LINK_TABLE + # - name: NODE_STATUSES + # - name: LINK_TRAFFIC_LEVELS # Number of episodes to run per session -num_episodes: 1000 +num_episodes: 10 # Number of time_steps per episode num_steps: 256 # Time delay between steps (for generic agents) -time_delay: 0 +time_delay: 10 # Type of session to be run (TRAINING or EVALUATION) session_type: TRAINING # Determine whether to load an agent from file From 605ff98a24eaca0e34b3d4c24e0dc8b4fc42761b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 30 Jun 2023 15:43:15 +0100 
Subject: [PATCH 26/43] Fix flattening when there are no components. --- src/primaite/environment/observations.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index fcd52559..b19bd29f 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -450,7 +450,10 @@ class ObservationsHandler: self._space = component_spaces[0] else: self._space = spaces.Tuple(component_spaces) - self._flat_space = spaces.flatten_space(self._space) + if len(component_spaces) > 0: + self._flat_space = spaces.flatten_space(self._space) + else: + self._flat_space = spaces.Box(0, 1, (0,)) @property def space(self): From 27ca53878af779f3df190998b46cf6e060c8cbcb Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 30 Jun 2023 16:52:57 +0100 Subject: [PATCH 27/43] #917 - Fixed the RLlib integration - Dropped support for overriding the num_episodes and num_steps at the agent level. It's just not needed and will add complexity when overriding and writing output files. 
--- .pre-commit-config.yaml | 2 +- pyproject.toml | 4 +- src/primaite/__init__.py | 27 +--- src/primaite/acl/access_control_list.py | 55 ++----- src/primaite/agents/agent.py | 71 ++------- src/primaite/agents/hardcoded_acl.py | 66 +++------ src/primaite/agents/hardcoded_node.py | 14 +- src/primaite/agents/rllib.py | 74 ++-------- src/primaite/agents/sb3.py | 66 +++------ src/primaite/agents/simple.py | 6 +- src/primaite/agents/utils.py | 43 ++---- src/primaite/cli.py | 29 +--- src/primaite/config/lay_down_config.py | 8 +- src/primaite/config/training_config.py | 25 +--- src/primaite/environment/observations.py | 38 ++--- src/primaite/environment/primaite_env.py | 136 ++++++------------ src/primaite/environment/reward.py | 41 ++---- src/primaite/links/link.py | 4 +- src/primaite/main.py | 8 +- src/primaite/nodes/active_node.py | 49 ++----- .../nodes/node_state_instruction_green.py | 4 +- .../nodes/node_state_instruction_red.py | 4 +- src/primaite/nodes/passive_node.py | 4 +- src/primaite/nodes/service_node.py | 24 +--- src/primaite/pol/green_pol.py | 66 ++------- src/primaite/pol/red_agent_pol.py | 63 ++------ src/primaite/primaite_session.py | 119 ++++----------- src/primaite/setup/reset_demo_notebooks.py | 12 +- src/primaite/setup/reset_example_configs.py | 8 +- src/primaite/transactions/transaction.py | 8 +- src/primaite/utils/session_output_writer.py | 4 +- .../legacy_training_config.yaml | 0 .../new_training_config.yaml | 0 tests/conftest.py | 10 +- tests/mock_and_patch/get_session_path_mock.py | 4 +- tests/test_acl.py | 4 +- tests/test_active_node.py | 12 +- tests/test_observation_space.py | 18 +-- tests/test_primaite_session.py | 10 +- tests/test_resetting_node.py | 16 +-- tests/test_service_node.py | 8 +- tests/test_single_action_space.py | 4 +- tests/test_training_config.py | 12 +- 43 files changed, 284 insertions(+), 896 deletions(-) rename tests/config/{legacy => legacy_conversion}/legacy_training_config.yaml (100%) rename tests/config/{legacy => 
legacy_conversion}/new_training_config.yaml (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 26cd5697..6e435bee 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: rev: 23.1.0 hooks: - id: black - args: [ "--line-length=79" ] + args: [ "--line-length=120" ] additional_dependencies: - jupyter - repo: http://github.com/pycqa/isort diff --git a/pyproject.toml b/pyproject.toml index b63e83fc..dc04f609 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,9 +72,9 @@ primaite = "primaite.cli:app" [tool.isort] profile = "black" -line_length = 79 +line_length = 120 force_sort_within_sections = "False" order_by_type = "False" [tool.black] -line-length = 79 +line-length = 120 diff --git a/src/primaite/__init__.py b/src/primaite/__init__.py index e753b4ef..030860d8 100644 --- a/src/primaite/__init__.py +++ b/src/primaite/__init__.py @@ -19,11 +19,7 @@ _PLATFORM_DIRS: Final[PlatformDirs] = PlatformDirs(appname="primaite") def _get_primaite_config(): config_path = _PLATFORM_DIRS.user_config_path / "primaite_config.yaml" if not config_path.exists(): - config_path = Path( - pkg_resources.resource_filename( - "primaite", "setup/_package_data/primaite_config.yaml" - ) - ) + config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) with open(config_path, "r") as file: primaite_config = yaml.safe_load(file) log_level_map = { @@ -34,9 +30,7 @@ def _get_primaite_config(): "ERROR": logging.ERROR, "CRITICAL": logging.CRITICAL, } - primaite_config["log_level"] = log_level_map[ - primaite_config["logging"]["log_level"] - ] + primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]] return primaite_config @@ -82,14 +76,9 @@ class _LevelFormatter(Formatter): super().__init__() if "fmt" in kwargs: - raise ValueError( - "Format string must be passed to level-surrogate formatters, " - "not this one" - ) + raise ValueError("Format string must be 
passed to level-surrogate formatters, " "not this one") - self.formats = sorted( - (level, Formatter(fmt, **kwargs)) for level, fmt in formats.items() - ) + self.formats = sorted((level, Formatter(fmt, **kwargs)) for level, fmt in formats.items()) def format(self, record: LogRecord) -> str: """Overrides ``Formatter.format``.""" @@ -110,13 +99,9 @@ _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( { logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], - logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"][ - "WARNING" - ], + logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], - logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"][ - "CRITICAL" - ], + logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"], } ) diff --git a/src/primaite/acl/access_control_list.py b/src/primaite/acl/access_control_list.py index a147d963..3b0e9234 100644 --- a/src/primaite/acl/access_control_list.py +++ b/src/primaite/acl/access_control_list.py @@ -10,9 +10,7 @@ class AccessControlList: def __init__(self): """Init.""" - self.acl: Dict[ - str, AccessControlList - ] = {} # A dictionary of ACL Rules + self.acl: Dict[str, AccessControlList] = {} # A dictionary of ACL Rules def check_address_match(self, _rule, _source_ip_address, _dest_ip_address): """ @@ -27,29 +25,16 @@ class AccessControlList: True if match; False otherwise. 
""" if ( - ( - _rule.get_source_ip() == _source_ip_address - and _rule.get_dest_ip() == _dest_ip_address - ) - or ( - _rule.get_source_ip() == "ANY" - and _rule.get_dest_ip() == _dest_ip_address - ) - or ( - _rule.get_source_ip() == _source_ip_address - and _rule.get_dest_ip() == "ANY" - ) - or ( - _rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY" - ) + (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == _dest_ip_address) + or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == _dest_ip_address) + or (_rule.get_source_ip() == _source_ip_address and _rule.get_dest_ip() == "ANY") + or (_rule.get_source_ip() == "ANY" and _rule.get_dest_ip() == "ANY") ): return True else: return False - def is_blocked( - self, _source_ip_address, _dest_ip_address, _protocol, _port - ): + def is_blocked(self, _source_ip_address, _dest_ip_address, _protocol, _port): """ Checks for rules that block a protocol / port. @@ -63,15 +48,9 @@ class AccessControlList: Indicates block if all conditions are satisfied. """ for rule_key, rule_value in self.acl.items(): - if self.check_address_match( - rule_value, _source_ip_address, _dest_ip_address - ): - if ( - rule_value.get_protocol() == _protocol - or rule_value.get_protocol() == "ANY" - ) and ( - str(rule_value.get_port()) == str(_port) - or rule_value.get_port() == "ANY" + if self.check_address_match(rule_value, _source_ip_address, _dest_ip_address): + if (rule_value.get_protocol() == _protocol or rule_value.get_protocol() == "ANY") and ( + str(rule_value.get_port()) == str(_port) or rule_value.get_port() == "ANY" ): # There's a matching rule. 
Get the permission if rule_value.get_permission() == "DENY": @@ -93,9 +72,7 @@ class AccessControlList: _protocol: the protocol _port: the port """ - new_rule = ACLRule( - _permission, _source_ip, _dest_ip, _protocol, str(_port) - ) + new_rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(new_rule) self.acl[hash_value] = new_rule @@ -110,9 +87,7 @@ class AccessControlList: _protocol: the protocol _port: the port """ - rule = ACLRule( - _permission, _source_ip, _dest_ip, _protocol, str(_port) - ) + rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) # There will not always be something 'popable' since the agent will be trying random things try: @@ -124,9 +99,7 @@ class AccessControlList: """Removes all rules.""" self.acl.clear() - def get_dictionary_hash( - self, _permission, _source_ip, _dest_ip, _protocol, _port - ): + def get_dictionary_hash(self, _permission, _source_ip, _dest_ip, _protocol, _port): """ Produces a hash value for a rule. @@ -140,8 +113,6 @@ class AccessControlList: Returns: Hash value based on rule parameters. 
""" - rule = ACLRule( - _permission, _source_ip, _dest_ip, _protocol, str(_port) - ) + rule = ACLRule(_permission, _source_ip, _dest_ip, _protocol, str(_port)) hash_value = hash(rule) return hash_value diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index c76583c0..90eb2b66 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -5,7 +5,7 @@ import time from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path -from typing import Dict, Final, Optional, Union +from typing import Dict, Final, Union from uuid import uuid4 import yaml @@ -51,16 +51,12 @@ class AgentSessionABC(ABC): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path - self._training_config: Final[TrainingConfig] = training_config.load( - self._training_config_path - ) + self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load( - self._lay_down_config_path - ) + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) self.output_verbose_level = self._training_config.output_verbose_level self._env: Primaite @@ -132,9 +128,7 @@ class AgentSessionABC(ABC): "learning": {"total_episodes": None, "total_time_steps": None}, "evaluation": {"total_episodes": None, "total_time_steps": None}, "env": { - "training_config": self._training_config.to_dict( - json_serializable=True - ), + "training_config": self._training_config.to_dict(json_serializable=True), "lay_down_config": self._lay_down_config, }, } @@ -161,19 +155,11 @@ class AgentSessionABC(ABC): metadata_dict["end_datetime"] = datetime.now().isoformat() if not self.is_eval: - 
metadata_dict["learning"][ - "total_episodes" - ] = self._env.episode_count # noqa - metadata_dict["learning"][ - "total_time_steps" - ] = self._env.total_step_count # noqa + metadata_dict["learning"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["learning"]["total_time_steps"] = self._env.total_step_count # noqa else: - metadata_dict["evaluation"][ - "total_episodes" - ] = self._env.episode_count # noqa - metadata_dict["evaluation"][ - "total_time_steps" - ] = self._env.total_step_count # noqa + metadata_dict["evaluation"]["total_episodes"] = self._env.episode_count # noqa + metadata_dict["evaluation"]["total_time_steps"] = self._env.total_step_count # noqa filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -184,12 +170,9 @@ class AgentSessionABC(ABC): @abstractmethod def _setup(self): _LOGGER.info( - "Welcome to the Primary-level AI Training Environment " - f"(PrimAITE) (version: {primaite.__version__})" - ) - _LOGGER.info( - f"The output directory for this session is: {self.session_path}" + "Welcome to the Primary-level AI Training Environment " f"(PrimAITE) (version: {primaite.__version__})" ) + _LOGGER.info(f"The output directory for this session is: {self.session_path}") self._write_session_metadata_file() self._can_learn = True self._can_evaluate = False @@ -201,17 +184,11 @@ class AgentSessionABC(ABC): @abstractmethod def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. 
""" if self._can_learn: @@ -225,17 +202,11 @@ class AgentSessionABC(ABC): @abstractmethod def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ self._env.set_as_eval() # noqa @@ -281,9 +252,7 @@ class AgentSessionABC(ABC): else: # Session path does not exist - msg = ( - f"Failed to load PrimAITE Session, path does not exist: {path}" - ) + msg = f"Failed to load PrimAITE Session, path does not exist: {path}" _LOGGER.error(msg) raise FileNotFoundError(msg) pass @@ -354,17 +323,11 @@ class HardCodedAgentSessionABC(AgentSessionABC): def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ _LOGGER.warning("Deterministic agents cannot learn") @@ -375,27 +338,19 @@ class HardCodedAgentSessionABC(AgentSessionABC): def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. 
""" self._env.set_as_eval() # noqa self.is_eval = True - if not time_steps: - time_steps = self._training_config.num_steps + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes - if not episodes: - episodes = self._training_config.num_episodes obs = self._env.reset() for episode in range(episodes): # Reset env and collect initial observation diff --git a/src/primaite/agents/hardcoded_acl.py b/src/primaite/agents/hardcoded_acl.py index f70320f1..263ccbdc 100644 --- a/src/primaite/agents/hardcoded_acl.py +++ b/src/primaite/agents/hardcoded_acl.py @@ -14,10 +14,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): """An Agent Session class that implements a deterministic ACL agent.""" def _calculate_action(self, obs): - if ( - self._training_config.hard_coded_agent_view - == HardCodedAgentView.BASIC - ): + if self._training_config.hard_coded_agent_view == HardCodedAgentView.BASIC: # Basic view action using only the current observation return self._calculate_action_basic_view(obs) else: @@ -43,9 +40,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): port = green_ier.get_port() # Can be blocked by an ACL or by default (no allow rule exists) - if acl.is_blocked( - source_node_address, dest_node_address, protocol, port - ): + if acl.is_blocked(source_node_address, dest_node_address, protocol, port): blocked_green_iers[green_ier_id] = green_ier return blocked_green_iers @@ -64,9 +59,7 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): protocol = ier.get_protocol() # e.g. 
'TCP' port = ier.get_port() - matching_rules = acl.get_relevant_rules( - source_node_address, dest_node_address, protocol, port - ) + matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules def get_blocking_acl_rules_for_ier(self, ier, acl, nodes): @@ -132,13 +125,9 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): dest_node_address = dest_node_id if protocol != "ANY": - protocol = services_list[ - protocol - 1 - ] # -1 as dont have to account for ANY in list of services + protocol = services_list[protocol - 1] # -1 as dont have to account for ANY in list of services - matching_rules = acl.get_relevant_rules( - source_node_address, dest_node_address, protocol, port - ) + matching_rules = acl.get_relevant_rules(source_node_address, dest_node_address, protocol, port) return matching_rules def get_allow_acl_rules( @@ -283,19 +272,12 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): action_decision = "DELETE" action_permission = "ALLOW" action_source_ip = rule.get_source_ip() - action_source_id = int( - get_node_of_ip(action_source_ip, self._env.nodes) - ) + action_source_id = int(get_node_of_ip(action_source_ip, self._env.nodes)) action_destination_ip = rule.get_dest_ip() - action_destination_id = int( - get_node_of_ip( - action_destination_ip, self._env.nodes - ) - ) + action_destination_id = int(get_node_of_ip(action_destination_ip, self._env.nodes)) action_protocol_name = rule.get_protocol() action_protocol = ( - self._env.services_list.index(action_protocol_name) - + 1 + self._env.services_list.index(action_protocol_name) + 1 ) # convert name e.g. 
'TCP' to index action_port_name = rule.get_port() action_port = ( @@ -330,22 +312,16 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if not found_action: # Which Green IERS are blocked - blocked_green_iers = self.get_blocked_green_iers( - self._env.green_iers, self._env.acl, self._env.nodes - ) + blocked_green_iers = self.get_blocked_green_iers(self._env.green_iers, self._env.acl, self._env.nodes) for ier_key, ier in blocked_green_iers.items(): # Which ALLOW rules are allowing this IER (none) - allowing_rules = self.get_allow_acl_rules_for_ier( - ier, self._env.acl, self._env.nodes - ) + allowing_rules = self.get_allow_acl_rules_for_ier(ier, self._env.acl, self._env.nodes) # If there are no blocking rules, it may be being blocked by default # If there is already an allow rule node_id_to_check = int(ier.get_source_node_id()) service_name_to_check = ier.get_protocol() - service_id_to_check = self._env.services_list.index( - service_name_to_check - ) + service_id_to_check = self._env.services_list.index(service_name_to_check) # Service state of the the source node in the ier service_state = s[service_id_to_check][node_id_to_check - 1] @@ -413,31 +389,21 @@ class HardCodedACLAgent(HardCodedAgentSessionABC): if len(r_obs) == 4: # only 1 service s = [*s] - number_of_nodes = len( - [i for i in o if i != "NONE"] - ) # number of nodes (not links) + number_of_nodes = len([i for i in o if i != "NONE"]) # number of nodes (not links) for service_num, service_states in enumerate(s): - comprimised_states = [ - n for n, i in enumerate(service_states) if i == "COMPROMISED" - ] + comprimised_states = [n for n, i in enumerate(service_states) if i == "COMPROMISED"] if len(comprimised_states) == 0: # No states are COMPROMISED, try the next service continue - compromised_node = ( - np.random.choice(comprimised_states) + 1 - ) # +1 as 0 would be any + compromised_node = np.random.choice(comprimised_states) + 1 # +1 as 0 would be any action_decision = "DELETE" action_permission = 
"ALLOW" action_source_ip = compromised_node # Randomly select a destination ID to block - action_destination_ip = np.random.choice( - list(range(1, number_of_nodes + 1)) + ["ANY"] - ) + action_destination_ip = np.random.choice(list(range(1, number_of_nodes + 1)) + ["ANY"]) action_destination_ip = ( - int(action_destination_ip) - if action_destination_ip != "ANY" - else action_destination_ip + int(action_destination_ip) if action_destination_ip != "ANY" else action_destination_ip ) action_protocol = service_num + 1 # +1 as 0 is any # Randomly select a port to block diff --git a/src/primaite/agents/hardcoded_node.py b/src/primaite/agents/hardcoded_node.py index e258edb0..310fc178 100644 --- a/src/primaite/agents/hardcoded_node.py +++ b/src/primaite/agents/hardcoded_node.py @@ -1,9 +1,5 @@ from primaite.agents.agent import HardCodedAgentSessionABC -from primaite.agents.utils import ( - get_new_action, - transform_action_node_enum, - transform_change_obs_readable, -) +from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable class HardCodedNodeAgent(HardCodedAgentSessionABC): @@ -93,12 +89,8 @@ class HardCodedNodeAgent(HardCodedAgentSessionABC): if os_state == "OFF": action_node_id = x + 1 action_node_property = "OPERATING" - property_action = ( - "ON" # Why reset it when we can just turn it on - ) - action_service_index = ( - 0 # does nothing isn't relevant for operating state - ) + property_action = "ON" # Why reset it when we can just turn it on + action_service_index = 0 # does nothing isn't relevant for operating state action = [ action_node_id, action_node_property, diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 35ae1b53..2b6a5a83 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -3,9 +3,8 @@ from __future__ import annotations import json from datetime import datetime from pathlib import Path -from typing import Optional, Union +from typing import Union 
-import tensorflow as tf from ray.rllib.algorithms import Algorithm from ray.rllib.algorithms.a2c import A2CConfig from ray.rllib.algorithms.ppo import PPOConfig @@ -14,11 +13,7 @@ from ray.tune.registry import register_env from primaite import getLogger from primaite.agents.agent import AgentSessionABC -from primaite.common.enums import ( - AgentFramework, - AgentIdentifier, - DeepLearningFramework, -) +from primaite.common.enums import AgentFramework, AgentIdentifier from primaite.environment.primaite_env import Primaite _LOGGER = getLogger(__name__) @@ -49,10 +44,7 @@ class RLlibAgent(AgentSessionABC): def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.RLLIB: - msg = ( - f"Expected RLLIB agent_framework, " - f"got {self._training_config.agent_framework}" - ) + msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) raise ValueError(msg) if self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -60,10 +52,7 @@ class RLlibAgent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_config_class = A2CConfig else: - msg = ( - "Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier.value}" - ) + msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" _LOGGER.error(msg) raise ValueError(msg) self._agent_config: PPOConfig @@ -94,12 +83,8 @@ class RLlibAgent(AgentSessionABC): metadata_dict = json.load(file) metadata_dict["end_datetime"] = datetime.now().isoformat() - metadata_dict["total_episodes"] = self._current_result[ - "episodes_total" - ] - metadata_dict["total_time_steps"] = self._current_result[ - "timesteps_total" - ] + metadata_dict["total_episodes"] = self._current_result["episodes_total"] + metadata_dict["total_time_steps"] = 
self._current_result["timesteps_total"] filepath = self.session_path / "session_metadata.json" _LOGGER.debug(f"Updating Session Metadata file: {filepath}") @@ -122,9 +107,7 @@ class RLlibAgent(AgentSessionABC): ), ) - self._agent_config.training( - train_batch_size=self._training_config.num_steps - ) + self._agent_config.training(train_batch_size=self._training_config.num_steps) self._agent_config.framework(framework="tf") self._agent_config.rollouts( @@ -132,72 +115,41 @@ class RLlibAgent(AgentSessionABC): num_envs_per_worker=1, horizon=self._training_config.num_steps, ) - self._agent: Algorithm = self._agent_config.build( - logger_creator=_custom_log_creator(self.session_path) - ) + self._agent: Algorithm = self._agent_config.build(logger_creator=_custom_log_creator(self.learning_path)) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._current_result["episodes_total"] if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or ( - episode_count == self._training_config.num_episodes - ): + if (episode_count % checkpoint_n == 0) or (episode_count == self._training_config.num_episodes): self._agent.save(str(self.checkpoints_path)) def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. 
""" - # Temporarily override train_batch_size and horizon - if time_steps: - self._agent_config.train_batch_size = time_steps - self._agent_config.horizon = time_steps + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes - if not episodes: - episodes = self._training_config.num_episodes - _LOGGER.info( - f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps..." - ) + _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): self._current_result = self._agent.train() self._save_checkpoint() - if ( - self._training_config.deep_learning_framework - != DeepLearningFramework.TORCH - ): - policy = self._agent.get_policy() - tf.compat.v1.summary.FileWriter( - self.session_path / "ray_results", policy.get_session().graph - ) - super().learn() self._agent.stop() + super().learn() def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. 
""" raise NotImplementedError diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 8d5dd633..3161c93a 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Optional, Union +from typing import Union import numpy as np from stable_baselines3 import A2C, PPO @@ -21,10 +21,7 @@ class SB3Agent(AgentSessionABC): def __init__(self, training_config_path, lay_down_config_path): super().__init__(training_config_path, lay_down_config_path) if not self._training_config.agent_framework == AgentFramework.SB3: - msg = ( - f"Expected SB3 agent_framework, " - f"got {self._training_config.agent_framework}" - ) + msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}" _LOGGER.error(msg) raise ValueError(msg) if self._training_config.agent_identifier == AgentIdentifier.PPO: @@ -32,10 +29,7 @@ class SB3Agent(AgentSessionABC): elif self._training_config.agent_identifier == AgentIdentifier.A2C: self._agent_class = A2C else: - msg = ( - "Expected PPO or A2C agent_identifier, " - f"got {self._training_config.agent_identifier}" - ) + msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier}" _LOGGER.error(msg) raise ValueError(msg) @@ -64,19 +58,15 @@ class SB3Agent(AgentSessionABC): self._env, verbose=self.output_verbose_level, n_steps=self._training_config.num_steps, - tensorboard_log=self._tensorboard_log_path, + tensorboard_log=str(self._tensorboard_log_path), ) def _save_checkpoint(self): checkpoint_n = self._training_config.checkpoint_every_n_episodes episode_count = self._env.episode_count if checkpoint_n > 0 and episode_count > 0: - if (episode_count % checkpoint_n == 0) or ( - episode_count == self._training_config.num_episodes - ): - checkpoint_path = ( - self.checkpoints_path / f"sb3ppo_{episode_count}.zip" - ) + if (episode_count % checkpoint_n == 0) or (episode_count == 
self._training_config.num_episodes): + checkpoint_path = self.checkpoints_path / f"sb3ppo_{episode_count}.zip" self._agent.save(checkpoint_path) _LOGGER.debug(f"Saved agent checkpoint: {checkpoint_path}") @@ -85,58 +75,37 @@ class SB3Agent(AgentSessionABC): def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param kwargs: Any agent-specific key-word args to be passed. """ - if not time_steps: - time_steps = self._training_config.num_steps - - if not episodes: - episodes = self._training_config.num_episodes + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes self.is_eval = False - _LOGGER.info( - f"Beginning learning for {episodes} episodes @" - f" {time_steps} time steps..." - ) + _LOGGER.info(f"Beginning learning for {episodes} episodes @" f" {time_steps} time steps...") for i in range(episodes): self._agent.learn(total_timesteps=time_steps) self._save_checkpoint() - - self.close() + self._env.reset() + self._env.close() super().learn() def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, deterministic: bool = True, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of steps per episode. Optional. If not - passed, the value from the training config will be used. - :param episodes: The number of episodes. Optional. If not - passed, the value from the training config will be used. :param deterministic: Whether the evaluation is deterministic. :param kwargs: Any agent-specific key-word args to be passed. 
""" - if not time_steps: - time_steps = self._training_config.num_steps - - if not episodes: - episodes = self._training_config.num_episodes + time_steps = self._training_config.num_steps + episodes = self._training_config.num_episodes self._env.set_as_eval() self.is_eval = True if deterministic: @@ -144,19 +113,18 @@ class SB3Agent(AgentSessionABC): else: deterministic_str = "non-deterministic" _LOGGER.info( - f"Beginning {deterministic_str} evaluation for " - f"{episodes} episodes @ {time_steps} time steps..." + f"Beginning {deterministic_str} evaluation for " f"{episodes} episodes @ {time_steps} time steps..." ) for episode in range(episodes): obs = self._env.reset() for step in range(time_steps): - action, _states = self._agent.predict( - obs, deterministic=deterministic - ) + action, _states = self._agent.predict(obs, deterministic=deterministic) if isinstance(action, np.ndarray): action = np.int64(action) obs, rewards, done, info = self._env.step(action) + self._env.reset() + self._env.close() super().evaluate() @classmethod diff --git a/src/primaite/agents/simple.py b/src/primaite/agents/simple.py index cf333b1e..5a6c9da5 100644 --- a/src/primaite/agents/simple.py +++ b/src/primaite/agents/simple.py @@ -1,9 +1,5 @@ from primaite.agents.agent import HardCodedAgentSessionABC -from primaite.agents.utils import ( - get_new_action, - transform_action_acl_enum, - transform_action_node_enum, -) +from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum class RandomAgent(HardCodedAgentSessionABC): diff --git a/src/primaite/agents/utils.py b/src/primaite/agents/utils.py index c3e67fdf..8c59faf7 100644 --- a/src/primaite/agents/utils.py +++ b/src/primaite/agents/utils.py @@ -24,9 +24,7 @@ def transform_action_node_readable(action): if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif ( - action_node_property == "OS" or action_node_property == "SERVICE" - ) and action[2] <= 1: 
+ elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -117,11 +115,7 @@ def is_valid_acl_action(action): if action_decision == "NONE": return False - if ( - action_source_id == action_destination_id - and action_source_id != "ANY" - and action_destination_id != "ANY" - ): + if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": # ACL rule towards itself return False if action_permission == "DENY": @@ -173,9 +167,7 @@ def transform_change_obs_readable(obs): for service in range(3, obs.shape[1]): # Links bit/s don't have a service state - service_states = [ - SoftwareState(i).name if i <= 4 else i for i in obs[:, service] - ] + service_states = [SoftwareState(i).name if i <= 4 else i for i in obs[:, service]] new_obs.append(service_states) return new_obs @@ -247,9 +239,7 @@ def convert_to_old_obs(obs, num_nodes=10, num_links=10, num_services=1): return new_obs -def describe_obs_change( - obs1, obs2, num_nodes=10, num_links=10, num_services=1 -): +def describe_obs_change(obs1, obs2, num_nodes=10, num_links=10, num_services=1): """ Return string describing change between two observations. @@ -291,16 +281,9 @@ def _describe_obs_change_helper(obs_change, is_link): TODO: Typehint params and return. 
""" # Indexes where a change has occured, not including 0th index - index_changed = [ - i for i in range(1, len(obs_change)) if obs_change[i] != -1 - ] + index_changed = [i for i in range(1, len(obs_change)) if obs_change[i] != -1] # Node pol types, Indexes >= 3 are service nodes - NodePOLTypes = [ - NodePOLType(i).name - if i < 3 - else NodePOLType(3).name + " " + str(i - 3) - for i in index_changed - ] + NodePOLTypes = [NodePOLType(i).name if i < 3 else NodePOLType(3).name + " " + str(i - 3) for i in index_changed] # Account for hardware states, software sattes and links states = [ LinkStatus(obs_change[i]).name @@ -367,9 +350,7 @@ def transform_action_node_readable(action): if action_node_property == "OPERATING": property_action = NodeHardwareAction(action[2]).name - elif ( - action_node_property == "OS" or action_node_property == "SERVICE" - ) and action[2] <= 1: + elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[2] <= 1: property_action = NodeSoftwareAction(action[2]).name else: property_action = "NONE" @@ -397,9 +378,7 @@ def node_action_description(action): if property_action == "NONE": return "" if node_property == "OPERATING" or node_property == "OS": - description = ( - f"NODE {node_id}, {node_property}, SET TO {property_action}" - ) + description = f"NODE {node_id}, {node_property}, SET TO {property_action}" elif node_property == "SERVICE": description = f"NODE {node_id} FROM SERVICE {service_id}, SET TO {property_action}" else: @@ -522,11 +501,7 @@ def is_valid_acl_action(action): if action_decision == "NONE": return False - if ( - action_source_id == action_destination_id - and action_source_id != "ANY" - and action_destination_id != "ANY" - ): + if action_source_id == action_destination_id and action_source_id != "ANY" and action_destination_id != "ANY": # ACL rule towards itself return False if action_permission == "DENY": diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 0431174f..40e8cf0d 100644 --- 
a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -56,9 +56,7 @@ def logs(last_n: Annotated[int, typer.Option("-n")]): print(re.sub(r"\n*", "", line)) -_LogLevel = Enum( - "LogLevel", {k: k for k in logging._levelToName.values()} -) # noqa +_LogLevel = Enum("LogLevel", {k: k for k in logging._levelToName.values()}) # noqa @app.command() @@ -124,21 +122,12 @@ def setup(overwrite_existing: bool = True): app_dirs = PlatformDirs(appname="primaite") app_dirs.user_config_path.mkdir(exist_ok=True, parents=True) user_config_path = app_dirs.user_config_path / "primaite_config.yaml" - pkg_config_path = Path( - pkg_resources.resource_filename( - "primaite", "setup/_package_data/primaite_config.yaml" - ) - ) + pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) shutil.copy2(pkg_config_path, user_config_path) from primaite import getLogger - from primaite.setup import ( - old_installation_clean_up, - reset_demo_notebooks, - reset_example_configs, - setup_app_dirs, - ) + from primaite.setup import old_installation_clean_up, reset_demo_notebooks, reset_example_configs, setup_app_dirs _LOGGER = getLogger(__name__) @@ -188,9 +177,7 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None): @app.command() -def plotly_template( - template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None -): +def plotly_template(template: Annotated[Optional[PlotlyTemplate], typer.Argument()] = None): """ View or set the plotly template for Session plots. 
@@ -208,14 +195,10 @@ def plotly_template( primaite_config = yaml.safe_load(file) if template: - primaite_config["session"]["outputs"]["plots"][ - "template" - ] = template.value + primaite_config["session"]["outputs"]["plots"]["template"] = template.value with open(user_config_path, "w") as file: yaml.dump(primaite_config, file) print(f"PrimAITE plotly template: {template.value}") else: - template = primaite_config["session"]["outputs"]["plots"][ - "template" - ] + template = primaite_config["session"]["outputs"]["plots"]["template"] print(f"PrimAITE plotly template: {template}") diff --git a/src/primaite/config/lay_down_config.py b/src/primaite/config/lay_down_config.py index ae067228..08f77b2f 100644 --- a/src/primaite/config/lay_down_config.py +++ b/src/primaite/config/lay_down_config.py @@ -8,14 +8,10 @@ from primaite import getLogger, USERS_CONFIG_DIR _LOGGER = getLogger(__name__) -_EXAMPLE_LAY_DOWN: Final[Path] = ( - USERS_CONFIG_DIR / "example_config" / "lay_down" -) +_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down" -def convert_legacy_lay_down_config_dict( - legacy_config_dict: Dict[str, Any] -) -> Dict[str, Any]: +def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]: """ Convert a legacy lay down config dict to the new format. 
diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 84dd3cc8..3e0f26ca 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -20,9 +20,7 @@ from primaite.common.enums import ( _LOGGER = getLogger(__name__) -_EXAMPLE_TRAINING: Final[Path] = ( - USERS_CONFIG_DIR / "example_config" / "training" -) +_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" def main_training_config_path() -> Path: @@ -68,9 +66,7 @@ class TrainingConfig: checkpoint_every_n_episodes: int = 5 "The agent will save a checkpoint every n episodes" - observation_space: dict = field( - default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]} - ) + observation_space: dict = field(default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}) "The observation space config dict" time_delay: int = 10 @@ -180,9 +176,7 @@ class TrainingConfig: "The time taken to scan the file system" @classmethod - def from_dict( - cls, config_dict: Dict[str, Union[str, int, bool]] - ) -> TrainingConfig: + def from_dict(cls, config_dict: Dict[str, Union[str, int, bool]]) -> TrainingConfig: """ Create an instance of TrainingConfig from a dict. @@ -238,9 +232,7 @@ class TrainingConfig: return tc -def load( - file_path: Union[str, Path], legacy_file: bool = False -) -> TrainingConfig: +def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConfig: """ Read in a training config yaml file. 
@@ -271,10 +263,7 @@ def load( try: return TrainingConfig.from_dict(config) except TypeError as e: - msg = ( - f"Error when creating an instance of {TrainingConfig} " - f"from the training config file {file_path}" - ) + msg = f"Error when creating an instance of {TrainingConfig} " f"from the training config file {file_path}" _LOGGER.critical(msg, exc_info=True) raise e msg = f"Cannot load the training config as it does not exist: {file_path}" @@ -314,9 +303,7 @@ def convert_legacy_training_config_dict( "output_verbose_level": output_verbose_level.name, } session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} - legacy_config_dict["sessionType"] = session_type_map[ - legacy_config_dict["sessionType"] - ] + legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]] for legacy_key, value in legacy_config_dict.items(): new_key = _get_new_key_from_legacy(legacy_key) if new_key: diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 6893125e..d0d5d46e 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -77,9 +77,7 @@ class NodeLinkTable(AbstractObservationComponent): ) # 3. Initialise Observation with zeroes - self.current_observation = np.zeros( - observation_shape, dtype=self._DATA_TYPE - ) + self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) def update(self): """Update the observation based on current environment state. 
@@ -94,12 +92,8 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation[item_index][0] = int(node.node_id) self.current_observation[item_index][1] = node.hardware_state.value if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): - self.current_observation[item_index][ - 2 - ] = node.software_state.value - self.current_observation[item_index][ - 3 - ] = node.file_system_state_observed.value + self.current_observation[item_index][2] = node.software_state.value + self.current_observation[item_index][3] = node.file_system_state_observed.value else: self.current_observation[item_index][2] = 0 self.current_observation[item_index][3] = 0 @@ -107,9 +101,7 @@ class NodeLinkTable(AbstractObservationComponent): if isinstance(node, ServiceNode): for service in self.env.services_list: if node.has_service(service): - self.current_observation[item_index][ - service_index - ] = node.get_service_state(service).value + self.current_observation[item_index][service_index] = node.get_service_state(service).value else: self.current_observation[item_index][service_index] = 0 service_index += 1 @@ -129,9 +121,7 @@ class NodeLinkTable(AbstractObservationComponent): protocol_list = link.get_protocol_list() protocol_index = 0 for protocol in protocol_list: - self.current_observation[item_index][ - protocol_index + 4 - ] = protocol.get_load() + self.current_observation[item_index][protocol_index + 4] = protocol.get_load() protocol_index += 1 item_index += 1 @@ -203,9 +193,7 @@ class NodeStatuses(AbstractObservationComponent): if isinstance(node, ServiceNode): for i, service in enumerate(self.env.services_list): if node.has_service(service): - service_states[i] = node.get_service_state( - service - ).value + service_states[i] = node.get_service_state(service).value obs.extend( [ hardware_state, @@ -269,11 +257,7 @@ class LinkTrafficLevels(AbstractObservationComponent): self._entries_per_link = self.env.num_services # 1. 
Define the shape of your observation space component - shape = ( - [self._quantisation_levels] - * self.env.num_links - * self._entries_per_link - ) + shape = [self._quantisation_levels] * self.env.num_links * self._entries_per_link # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) @@ -292,9 +276,7 @@ class LinkTrafficLevels(AbstractObservationComponent): if self._combine_service_traffic: loads = [link.get_current_load()] else: - loads = [ - protocol.get_load() for protocol in link.protocol_list - ] + loads = [protocol.get_load() for protocol in link.protocol_list] for load in loads: if load <= 0: @@ -302,9 +284,7 @@ class LinkTrafficLevels(AbstractObservationComponent): elif load >= bandwidth: traffic_level = self._quantisation_levels - 1 else: - traffic_level = (load / bandwidth) // ( - 1 / (self._quantisation_levels - 2) - ) + 1 + traffic_level = (load / bandwidth) // (1 / (self._quantisation_levels - 2)) + 1 obs.append(int(traffic_level)) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index ea8f82d4..df51e21e 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -12,13 +12,11 @@ from matplotlib import pyplot as plt from primaite import getLogger from primaite.acl.access_control_list import AccessControlList -from primaite.agents.utils import ( - is_valid_acl_action_extra, - is_valid_node_action, -) +from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action from primaite.common.custom_typing import NodeUnion from primaite.common.enums import ( ActionType, + AgentFramework, FileSystemState, HardwareState, NodePOLInitiator, @@ -37,18 +35,13 @@ from primaite.environment.reward import calculate_reward_function from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node import Node -from primaite.nodes.node_state_instruction_green import ( - NodeStateInstructionGreen, -) 
+from primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.passive_node import PassiveNode from primaite.nodes.service_node import ServiceNode from primaite.pol.green_pol import apply_iers, apply_node_pol from primaite.pol.ier import IER -from primaite.pol.red_agent_pol import ( - apply_red_agent_iers, - apply_red_agent_node_pol, -) +from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_node_pol from primaite.transactions.transaction import Transaction from primaite.utils.session_output_writer import SessionOutputWriter @@ -85,9 +78,7 @@ class Primaite(Env): self._training_config_path = training_config_path self._lay_down_config_path = lay_down_config_path - self.training_config: TrainingConfig = training_config.load( - training_config_path - ) + self.training_config: TrainingConfig = training_config.load(training_config_path) _LOGGER.info(f"Using: {str(self.training_config)}") # Number of steps in an episode @@ -238,25 +229,22 @@ class Primaite(Env): self.action_dict = self.create_node_and_acl_action_dict() self.action_space = spaces.Discrete(len(self.action_dict)) else: - _LOGGER.error( - f"Invalid action type selected: {self.training_config.action_type}" - ) + _LOGGER.error(f"Invalid action type selected: {self.training_config.action_type}") - self.episode_av_reward_writer = SessionOutputWriter( - self, transaction_writer=False, learning_session=True - ) - self.transaction_writer = SessionOutputWriter( - self, transaction_writer=True, learning_session=True - ) + self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=True) + self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=True) + + @property + def actual_episode_count(self) -> int: + """Shifts the episode_count by -1 for RLlib.""" + if self.training_config.agent_framework is 
AgentFramework.RLLIB: + return self.episode_count - 1 + return self.episode_count def set_as_eval(self): """Set the writers to write to eval directories.""" - self.episode_av_reward_writer = SessionOutputWriter( - self, transaction_writer=False, learning_session=False - ) - self.transaction_writer = SessionOutputWriter( - self, transaction_writer=True, learning_session=False - ) + self.episode_av_reward_writer = SessionOutputWriter(self, transaction_writer=False, learning_session=False) + self.transaction_writer = SessionOutputWriter(self, transaction_writer=True, learning_session=False) self.episode_count = 0 self.step_count = 0 self.total_step_count = 0 @@ -268,8 +256,8 @@ class Primaite(Env): Returns: Environment observation space (reset) """ - if self.episode_count > 0: - csv_data = self.episode_count, self.average_reward + if self.actual_episode_count > 0: + csv_data = self.actual_episode_count, self.average_reward self.episode_av_reward_writer.write(csv_data) self.episode_count += 1 @@ -291,6 +279,7 @@ class Primaite(Env): # Update observations space and return self.update_environent_obs() + return self.env_obs def step(self, action): @@ -319,9 +308,7 @@ class Primaite(Env): link.clear_traffic() # Create a Transaction (metric) object for this step - transaction = Transaction( - self.agent_identifier, self.episode_count, self.step_count - ) + transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) # Load the initial observation space into the transaction transaction.obs_space_pre = copy.deepcopy(self.env_obs) # Load the action space into the transaction @@ -350,9 +337,7 @@ class Primaite(Env): self.nodes_post_pol = copy.deepcopy(self.nodes) self.links_post_pol = copy.deepcopy(self.links) # Reference - apply_node_pol( - self.nodes_reference, self.node_pol, self.step_count - ) # Node PoL + apply_node_pol(self.nodes_reference, self.node_pol, self.step_count) # Node PoL apply_iers( self.network_reference, self.nodes_reference, 
@@ -371,9 +356,7 @@ class Primaite(Env): self.acl, self.step_count, ) - apply_red_agent_node_pol( - self.nodes, self.red_iers, self.red_node_pol, self.step_count - ) + apply_red_agent_node_pol(self.nodes, self.red_iers, self.red_node_pol, self.step_count) # Take snapshots of nodes and links self.nodes_post_red = copy.deepcopy(self.nodes) self.links_post_red = copy.deepcopy(self.links) @@ -389,11 +372,7 @@ class Primaite(Env): self.step_count, self.training_config, ) - _LOGGER.debug( - f"Episode: {self.episode_count}, " - f"Step {self.step_count}, " - f"Reward: {reward}" - ) + _LOGGER.debug(f"Episode: {self.actual_episode_count}, " f"Step {self.step_count}, " f"Reward: {reward}") self.total_reward += reward if self.step_count == self.episode_steps: self.average_reward = self.total_reward / self.step_count @@ -401,10 +380,7 @@ class Primaite(Env): # For evaluation, need to trigger the done value = True when # step count is reached in order to prevent neverending episode done = True - _LOGGER.info( - f"Episode: {self.episode_count}, " - f"Average Reward: {self.average_reward}" - ) + _LOGGER.info(f"Episode: {self.actual_episode_count}, " f"Average Reward: {self.average_reward}") # Load the reward into the transaction transaction.reward = reward @@ -417,11 +393,21 @@ class Primaite(Env): transaction.obs_space_post = copy.deepcopy(self.env_obs) # Write transaction to file - self.transaction_writer.write(transaction) + if self.actual_episode_count > 0: + self.transaction_writer.write(transaction) # Return return self.env_obs, reward, done, self.step_info + def close(self): + """Override parent close and close writers.""" + # Close files if last episode/step + # if self.can_finish: + super().close() + + self.transaction_writer.close() + self.episode_av_reward_writer.close() + def init_acl(self): """Initialise the Access Control List.""" self.acl.remove_all_rules() @@ -431,12 +417,7 @@ class Primaite(Env): for link_key, link_value in self.links.items(): _LOGGER.debug("Link 
ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - _LOGGER.debug( - " Protocol: " - + protocol.get_name().name - + ", Load: " - + str(protocol.get_load()) - ) + _LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) def interpret_action_and_apply(self, _action): """ @@ -450,13 +431,9 @@ class Primaite(Env): self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 6 - ): # ACL actions in multidiscrete form have len 6 + elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 4 - ): # Node actions in multdiscrete (array) from have len 4 + elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: _LOGGER.error("Invalid action type found") @@ -541,10 +518,7 @@ class Primaite(Env): elif property_action == 2: # Repair # You cannot repair a destroyed file system - it needs restoring - if ( - node.file_system_state_actual - != FileSystemState.DESTROYED - ): + if node.file_system_state_actual != FileSystemState.DESTROYED: node.set_file_system_state(FileSystemState.REPAIRING) elif property_action == 3: # Restore @@ -587,9 +561,7 @@ class Primaite(Env): acl_rule_source = "ANY" else: node = list(self.nodes.values())[action_source_ip - 1] - if isinstance(node, ServiceNode) or isinstance( - node, ActiveNode - ): + if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): acl_rule_source = node.ip_address else: return @@ -598,9 +570,7 @@ class Primaite(Env): acl_rule_destination = "ANY" else: node = list(self.nodes.values())[action_destination_ip - 1] - if isinstance(node, ServiceNode) or isinstance( - node, ActiveNode - ): + if isinstance(node, ServiceNode) or isinstance(node, ActiveNode): 
acl_rule_destination = node.ip_address else: return @@ -685,9 +655,7 @@ class Primaite(Env): :return: The observation space, initial observation (zeroed out array with the correct shape) :rtype: Tuple[spaces.Space, np.ndarray] """ - self.obs_handler = ObservationsHandler.from_config( - self, self.obs_config - ) + self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation @@ -794,9 +762,7 @@ class Primaite(Env): service_protocol = service["name"] service_port = service["port"] service_state = SoftwareState[service["state"]] - node.add_service( - Service(service_protocol, service_port, service_state) - ) + node.add_service(Service(service_protocol, service_port, service_state)) else: # Bad formatting pass @@ -849,9 +815,7 @@ class Primaite(Env): dest_node_ref: Node = self.nodes_reference[link_destination] # Add link to network (reference) - self.network_reference.add_edge( - source_node_ref, dest_node_ref, id=link_name - ) + self.network_reference.add_edge(source_node_ref, dest_node_ref, id=link_name) # Add link to link dictionary (reference) self.links_reference[link_name] = Link( @@ -1126,9 +1090,7 @@ class Primaite(Env): # All nodes have these parameters node_id = item["node_id"] node_class = item["node_class"] - node_hardware_state: HardwareState = HardwareState[ - item["hardware_state"] - ] + node_hardware_state: HardwareState = HardwareState[item["hardware_state"]] node: NodeUnion = self.nodes[node_id] node_ref = self.nodes_reference[node_id] @@ -1249,11 +1211,7 @@ class Primaite(Env): # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = { - k + len(acl_action_dict) - 1: v - for k, v in node_action_dict.items() - if k != 0 - } + new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} # Combine the Node dict and ACL dict combined_action_dict = 
{**acl_action_dict, **new_node_action_dict} diff --git a/src/primaite/environment/reward.py b/src/primaite/environment/reward.py index 4dd0550e..19094a18 100644 --- a/src/primaite/environment/reward.py +++ b/src/primaite/environment/reward.py @@ -41,29 +41,19 @@ def calculate_reward_function( reference_node = reference_nodes[node_key] # Hardware State - reward_value += score_node_operating_state( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_operating_state(final_node, initial_node, reference_node, config_values) # Software State - if isinstance(final_node, ActiveNode) or isinstance( - final_node, ServiceNode - ): - reward_value += score_node_os_state( - final_node, initial_node, reference_node, config_values - ) + if isinstance(final_node, ActiveNode) or isinstance(final_node, ServiceNode): + reward_value += score_node_os_state(final_node, initial_node, reference_node, config_values) # Service State if isinstance(final_node, ServiceNode): - reward_value += score_node_service_state( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_service_state(final_node, initial_node, reference_node, config_values) # File System State if isinstance(final_node, ActiveNode): - reward_value += score_node_file_system( - final_node, initial_node, reference_node, config_values - ) + reward_value += score_node_file_system(final_node, initial_node, reference_node, config_values) # Go through each red IER - penalise if it is running for ier_key, ier_value in red_iers.items(): @@ -82,10 +72,7 @@ def calculate_reward_function( if step_count >= start_step and step_count <= stop_step: reference_blocked = not reference_ier.get_is_running() live_blocked = not ier_value.get_is_running() - ier_reward = ( - config_values.green_ier_blocked - * ier_value.get_mission_criticality() - ) + ier_reward = config_values.green_ier_blocked * ier_value.get_mission_criticality() if live_blocked and not reference_blocked: 
reward_value += ier_reward @@ -107,9 +94,7 @@ def calculate_reward_function( return reward_value -def score_node_operating_state( - final_node, initial_node, reference_node, config_values -): +def score_node_operating_state(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the hardware state of a node. @@ -158,9 +143,7 @@ def score_node_operating_state( return score -def score_node_os_state( - final_node, initial_node, reference_node, config_values -): +def score_node_os_state(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the Software State of a node. @@ -211,9 +194,7 @@ def score_node_os_state( return score -def score_node_service_state( - final_node, initial_node, reference_node, config_values -): +def score_node_service_state(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the service state(s) of a node. @@ -285,9 +266,7 @@ def score_node_service_state( return score -def score_node_file_system( - final_node, initial_node, reference_node, config_values -): +def score_node_file_system(final_node, initial_node, reference_node, config_values): """ Calculates score relating to the file system state of a node. diff --git a/src/primaite/links/link.py b/src/primaite/links/link.py index 054f4c34..90235e9f 100644 --- a/src/primaite/links/link.py +++ b/src/primaite/links/link.py @@ -8,9 +8,7 @@ from primaite.common.protocol import Protocol class Link(object): """Link class.""" - def __init__( - self, _id, _bandwidth, _source_node_name, _dest_node_name, _services - ): + def __init__(self, _id, _bandwidth, _source_node_name, _dest_node_name, _services): """ Init. 
diff --git a/src/primaite/main.py b/src/primaite/main.py index 556c5ec3..7b1d7ab3 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -32,11 +32,7 @@ if __name__ == "__main__": parser.add_argument("--ldc") args = parser.parse_args() if not args.tc: - _LOGGER.error( - "Please provide a training config file using the --tc " "argument" - ) + _LOGGER.error("Please provide a training config file using the --tc " "argument") if not args.ldc: - _LOGGER.error( - "Please provide a lay down config file using the --ldc " "argument" - ) + _LOGGER.error("Please provide a lay down config file using the --ldc " "argument") run(training_config_path=args.tc, lay_down_config_path=args.ldc) diff --git a/src/primaite/nodes/active_node.py b/src/primaite/nodes/active_node.py index b1c3f57c..07a0ea0a 100644 --- a/src/primaite/nodes/active_node.py +++ b/src/primaite/nodes/active_node.py @@ -3,13 +3,7 @@ import logging from typing import Final -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.config.training_config import TrainingConfig from primaite.nodes.node import Node @@ -44,9 +38,7 @@ class ActiveNode(Node): :param file_system_state: The node file system state :param config_values: The config values """ - super().__init__( - node_id, name, node_type, priority, hardware_state, config_values - ) + super().__init__(node_id, name, node_type, priority, hardware_state, config_values) self.ip_address: str = ip_address # Related to Software self._software_state: SoftwareState = software_state @@ -87,9 +79,7 @@ class ActiveNode(Node): f"Node.software_state:{self._software_state}" ) - def set_software_state_if_not_compromised( - self, software_state: SoftwareState - ): + def set_software_state_if_not_compromised(self, software_state: SoftwareState): """ Sets Software State if the node is not 
compromised. @@ -100,9 +90,7 @@ class ActiveNode(Node): if self._software_state != SoftwareState.COMPROMISED: self._software_state = software_state if software_state == SoftwareState.PATCHING: - self.patching_count = ( - self.config_values.os_patching_duration - ) + self.patching_count = self.config_values.os_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so OS State cannot be changed." @@ -129,14 +117,10 @@ class ActiveNode(Node): self.file_system_state_actual = file_system_state if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = ( - self.config_values.file_system_repairing_limit - ) + self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_state_observed = FileSystemState.REPAIRING elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = ( - self.config_values.file_system_restoring_limit - ) + self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_state_observed = FileSystemState.RESTORING elif file_system_state == FileSystemState.GOOD: self.file_system_state_observed = FileSystemState.GOOD @@ -149,9 +133,7 @@ class ActiveNode(Node): f"Node.file_system_state.actual:{self.file_system_state_actual}" ) - def set_file_system_state_if_not_compromised( - self, file_system_state: FileSystemState - ): + def set_file_system_state_if_not_compromised(self, file_system_state: FileSystemState): """ Sets the file system state (actual and observed) if not in a compromised state. 
@@ -168,14 +150,10 @@ class ActiveNode(Node): self.file_system_state_actual = file_system_state if file_system_state == FileSystemState.REPAIRING: - self.file_system_action_count = ( - self.config_values.file_system_repairing_limit - ) + self.file_system_action_count = self.config_values.file_system_repairing_limit self.file_system_state_observed = FileSystemState.REPAIRING elif file_system_state == FileSystemState.RESTORING: - self.file_system_action_count = ( - self.config_values.file_system_restoring_limit - ) + self.file_system_action_count = self.config_values.file_system_restoring_limit self.file_system_state_observed = FileSystemState.RESTORING elif file_system_state == FileSystemState.GOOD: self.file_system_state_observed = FileSystemState.GOOD @@ -191,9 +169,7 @@ class ActiveNode(Node): def start_file_system_scan(self): """Starts a file system scan.""" self.file_system_scanning = True - self.file_system_scanning_count = ( - self.config_values.file_system_scanning_limit - ) + self.file_system_scanning_count = self.config_values.file_system_scanning_limit def update_file_system_state(self): """Updates file system status based on scanning/restore/repair cycle.""" @@ -212,10 +188,7 @@ class ActiveNode(Node): self.file_system_state_observed = FileSystemState.GOOD # Scanning updates - if ( - self.file_system_scanning == True - and self.file_system_scanning_count < 0 - ): + if self.file_system_scanning == True and self.file_system_scanning_count < 0: self.file_system_state_observed = self.file_system_state_actual self.file_system_scanning = False self.file_system_scanning_count = 0 diff --git a/src/primaite/nodes/node_state_instruction_green.py b/src/primaite/nodes/node_state_instruction_green.py index 04681807..2b1d94be 100644 --- a/src/primaite/nodes/node_state_instruction_green.py +++ b/src/primaite/nodes/node_state_instruction_green.py @@ -32,9 +32,7 @@ class NodeStateInstructionGreen(object): self.end_step = _end_step self.node_id = _node_id 
self.node_pol_type = _node_pol_type - self.service_name = ( - _service_name # Not used when not a service instruction - ) + self.service_name = _service_name # Not used when not a service instruction self.state = _state def get_start_step(self): diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index ba35067c..7f62fe24 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -42,9 +42,7 @@ class NodeStateInstructionRed(object): self.target_node_id = _target_node_id self.initiator = _pol_initiator self.pol_type: NodePOLType = _pol_type - self.service_name = ( - pol_protocol # Not used when not a service instruction - ) + self.service_name = pol_protocol # Not used when not a service instruction self.state = _pol_state self.source_node_id = _pol_source_node_id self.source_node_service = _pol_source_node_service diff --git a/src/primaite/nodes/passive_node.py b/src/primaite/nodes/passive_node.py index 6515097a..9aa5c7d7 100644 --- a/src/primaite/nodes/passive_node.py +++ b/src/primaite/nodes/passive_node.py @@ -28,9 +28,7 @@ class PassiveNode(Node): :param config_values: Config values. 
""" # Pass through to Super for now - super().__init__( - node_id, name, node_type, priority, hardware_state, config_values - ) + super().__init__(node_id, name, node_type, priority, hardware_state, config_values) @property def ip_address(self) -> str: diff --git a/src/primaite/nodes/service_node.py b/src/primaite/nodes/service_node.py index 6dcff73e..5d69df92 100644 --- a/src/primaite/nodes/service_node.py +++ b/src/primaite/nodes/service_node.py @@ -3,13 +3,7 @@ import logging from typing import Dict, Final -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -110,9 +104,7 @@ class ServiceNode(ActiveNode): return False return False - def set_service_state( - self, protocol_name: str, software_state: SoftwareState - ): + def set_service_state(self, protocol_name: str, software_state: SoftwareState): """ Sets the software_state of a service (protocol) on the node. 
@@ -130,9 +122,7 @@ class ServiceNode(ActiveNode): ) or software_state != SoftwareState.COMPROMISED: service_value.software_state = software_state if software_state == SoftwareState.PATCHING: - service_value.patching_count = ( - self.config_values.service_patching_duration - ) + service_value.patching_count = self.config_values.service_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so the state of a service " @@ -143,9 +133,7 @@ class ServiceNode(ActiveNode): f"Node.services[].software_state:{software_state}" ) - def set_service_state_if_not_compromised( - self, protocol_name: str, software_state: SoftwareState - ): + def set_service_state_if_not_compromised(self, protocol_name: str, software_state: SoftwareState): """ Sets the software_state of a service (protocol) on the node. @@ -161,9 +149,7 @@ class ServiceNode(ActiveNode): if service_value.software_state != SoftwareState.COMPROMISED: service_value.software_state = software_state if software_state == SoftwareState.PATCHING: - service_value.patching_count = ( - self.config_values.service_patching_duration - ) + service_value.patching_count = self.config_values.service_patching_duration else: _LOGGER.info( f"The Nodes hardware state is OFF so the state of a service " diff --git a/src/primaite/pol/green_pol.py b/src/primaite/pol/green_pol.py index aeae7add..e9dfef8c 100644 --- a/src/primaite/pol/green_pol.py +++ b/src/primaite/pol/green_pol.py @@ -6,17 +6,10 @@ from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( - HardwareState, - NodePOLType, - NodeType, - SoftwareState, -) +from primaite.common.enums import HardwareState, NodePOLType, NodeType, SoftwareState from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode -from primaite.nodes.node_state_instruction_green import ( - NodeStateInstructionGreen, -) +from 
primaite.nodes.node_state_instruction_green import NodeStateInstructionGreen from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from primaite.nodes.service_node import ServiceNode from primaite.pol.ier import IER @@ -93,9 +86,7 @@ def apply_iers( and source_node.software_state != SoftwareState.PATCHING ): if source_node.has_service(protocol): - if source_node.service_running( - protocol - ) and not source_node.service_is_overwhelmed(protocol): + if source_node.service_running(protocol) and not source_node.service_is_overwhelmed(protocol): source_valid = True else: source_valid = False @@ -110,10 +101,7 @@ def apply_iers( # 2. Check the dest node situation if dest_node.node_type == NodeType.SWITCH: # It's a switch - if ( - dest_node.hardware_state == HardwareState.ON - and dest_node.software_state != SoftwareState.PATCHING - ): + if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: dest_valid = True else: # IER no longer valid @@ -123,14 +111,9 @@ def apply_iers( pass else: # It's not a switch or an actuator (so active node) - if ( - dest_node.hardware_state == HardwareState.ON - and dest_node.software_state != SoftwareState.PATCHING - ): + if dest_node.hardware_state == HardwareState.ON and dest_node.software_state != SoftwareState.PATCHING: if dest_node.has_service(protocol): - if dest_node.service_running( - protocol - ) and not dest_node.service_is_overwhelmed(protocol): + if dest_node.service_running(protocol) and not dest_node.service_is_overwhelmed(protocol): dest_valid = True else: dest_valid = False @@ -143,9 +126,7 @@ def apply_iers( dest_valid = False # 3. 
Check that the ACL doesn't block it - acl_block = acl.is_blocked( - source_node.ip_address, dest_node.ip_address, protocol, port - ) + acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) if acl_block: if _VERBOSE: print( @@ -176,10 +157,7 @@ def apply_iers( # We might have a switch in the path, so check all nodes are operational for node in path_node_list: - if ( - node.hardware_state != HardwareState.ON - or node.software_state == SoftwareState.PATCHING - ): + if node.hardware_state != HardwareState.ON or node.software_state == SoftwareState.PATCHING: path_valid = False if path_valid: @@ -191,15 +169,11 @@ def apply_iers( # Check that the link capacity is not exceeded by the new load while count < path_node_list_length - 1: # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] - ) + edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if ( - link.get_current_load() + load - ) > link.get_bandwidth(): + if (link.get_current_load() + load) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -226,9 +200,7 @@ def apply_iers( else: # One of the nodes is not operational if _VERBOSE: - print( - "Path not valid - one or more nodes not operational" - ) + print("Path not valid - one or more nodes not operational") pass else: @@ -243,9 +215,7 @@ def apply_iers( def apply_node_pol( nodes: Dict[str, NodeUnion], - node_pol: Dict[ - any, Union[NodeStateInstructionGreen, NodeStateInstructionRed] - ], + node_pol: Dict[any, Union[NodeStateInstructionGreen, NodeStateInstructionRed]], step: int, ): """ @@ -277,22 +247,16 @@ def apply_node_pol( elif node_pol_type == NodePOLType.OS: # Change OS state # Don't allow PoL to fix something that is compromised. 
Only the Blue agent can do this - if isinstance(node, ActiveNode) or isinstance( - node, ServiceNode - ): + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): node.set_software_state_if_not_compromised(state) elif node_pol_type == NodePOLType.SERVICE: # Change a service state # Don't allow PoL to fix something that is compromised. Only the Blue agent can do this if isinstance(node, ServiceNode): - node.set_service_state_if_not_compromised( - service_name, state - ) + node.set_service_state_if_not_compromised(service_name, state) else: # Change the file system status - if isinstance(node, ActiveNode) or isinstance( - node, ServiceNode - ): + if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): node.set_file_system_state_if_not_compromised(state) else: # PoL is not valid in this time step diff --git a/src/primaite/pol/red_agent_pol.py b/src/primaite/pol/red_agent_pol.py index 96fe787c..bff19bf8 100644 --- a/src/primaite/pol/red_agent_pol.py +++ b/src/primaite/pol/red_agent_pol.py @@ -6,13 +6,7 @@ from networkx import MultiGraph, shortest_path from primaite.acl.access_control_list import AccessControlList from primaite.common.custom_typing import NodeUnion -from primaite.common.enums import ( - HardwareState, - NodePOLInitiator, - NodePOLType, - NodeType, - SoftwareState, -) +from primaite.common.enums import HardwareState, NodePOLInitiator, NodePOLType, NodeType, SoftwareState from primaite.links.link import Link from primaite.nodes.active_node import ActiveNode from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed @@ -83,10 +77,7 @@ def apply_red_agent_iers( if source_node.hardware_state == HardwareState.ON: if source_node.has_service(protocol): # Red agents IERs can only be valid if the source service is in a compromised state - if ( - source_node.get_service_state(protocol) - == SoftwareState.COMPROMISED - ): + if source_node.get_service_state(protocol) == SoftwareState.COMPROMISED: source_valid = True else: 
source_valid = False @@ -124,9 +115,7 @@ def apply_red_agent_iers( dest_valid = False # 3. Check that the ACL doesn't block it - acl_block = acl.is_blocked( - source_node.ip_address, dest_node.ip_address, protocol, port - ) + acl_block = acl.is_blocked(source_node.ip_address, dest_node.ip_address, protocol, port) if acl_block: if _VERBOSE: print( @@ -170,15 +159,11 @@ def apply_red_agent_iers( # Check that the link capacity is not exceeded by the new load while count < path_node_list_length - 1: # Get the link between the next two nodes - edge_dict = network.get_edge_data( - path_node_list[count], path_node_list[count + 1] - ) + edge_dict = network.get_edge_data(path_node_list[count], path_node_list[count + 1]) link_id = edge_dict[0].get("id") link = links[link_id] # Check whether the new load exceeds the bandwidth - if ( - link.get_current_load() + load - ) > link.get_bandwidth(): + if (link.get_current_load() + load) > link.get_bandwidth(): link_capacity_exceeded = True if _VERBOSE: print("Link capacity exceeded") @@ -203,23 +188,16 @@ def apply_red_agent_iers( # This IER is now valid, so set it to running ier_value.set_is_running(True) if _VERBOSE: - print( - "Red IER was allowed to run in step " - + str(step) - ) + print("Red IER was allowed to run in step " + str(step)) else: # One of the nodes is not operational if _VERBOSE: - print( - "Path not valid - one or more nodes not operational" - ) + print("Path not valid - one or more nodes not operational") pass else: if _VERBOSE: - print( - "Red IER was NOT allowed to run in step " + str(step) - ) + print("Red IER was NOT allowed to run in step " + str(step)) print("Source, Dest or ACL were not valid") pass # ------------------------------------ @@ -258,9 +236,7 @@ def apply_red_agent_node_pol( state = node_instruction.get_state() source_node_id = node_instruction.get_source_node_id() source_node_service_name = node_instruction.get_source_node_service() - source_node_service_state_value = ( - 
node_instruction.get_source_node_service_state() - ) + source_node_service_state_value = node_instruction.get_source_node_service_state() passed_checks = False @@ -274,9 +250,7 @@ def apply_red_agent_node_pol( passed_checks = True elif initiator == NodePOLInitiator.IER: # Need to check there is a red IER incoming - passed_checks = is_red_ier_incoming( - target_node, iers, pol_type - ) + passed_checks = is_red_ier_incoming(target_node, iers, pol_type) elif initiator == NodePOLInitiator.SERVICE: # Need to check the condition of a service on another node source_node = nodes[source_node_id] @@ -304,9 +278,7 @@ def apply_red_agent_node_pol( target_node.hardware_state = state elif pol_type == NodePOLType.OS: # Change OS state - if isinstance(target_node, ActiveNode) or isinstance( - target_node, ServiceNode - ): + if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): target_node.software_state = state elif pol_type == NodePOLType.SERVICE: # Change a service state @@ -314,15 +286,11 @@ def apply_red_agent_node_pol( target_node.set_service_state(service_name, state) else: # Change the file system status - if isinstance(target_node, ActiveNode) or isinstance( - target_node, ServiceNode - ): + if isinstance(target_node, ActiveNode) or isinstance(target_node, ServiceNode): target_node.set_file_system_state(state) else: if _VERBOSE: - print( - "Node Red Agent PoL not allowed - did not pass checks" - ) + print("Node Red Agent PoL not allowed - did not pass checks") else: # PoL is not valid in this time step pass @@ -337,10 +305,7 @@ def is_red_ier_incoming(node, iers, node_pol_type): node_id = node.node_id for ier_key, ier_value in iers.items(): - if ( - ier_value.get_is_running() - and ier_value.get_dest_node_id() == node_id - ): + if ier_value.get_is_running() and ier_value.get_dest_node_id() == node_id: if ( node_pol_type == NodePOLType.OPERATING or node_pol_type == NodePOLType.OS diff --git a/src/primaite/primaite_session.py 
b/src/primaite/primaite_session.py index 4d8d3022..4ee6c507 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -1,7 +1,7 @@ from __future__ import annotations from pathlib import Path -from typing import Dict, Final, Optional, Union +from typing import Dict, Final, Union from primaite import getLogger from primaite.agents.agent import AgentSessionABC @@ -9,18 +9,8 @@ from primaite.agents.hardcoded_acl import HardCodedACLAgent from primaite.agents.hardcoded_node import HardCodedNodeAgent from primaite.agents.rllib import RLlibAgent from primaite.agents.sb3 import SB3Agent -from primaite.agents.simple import ( - DoNothingACLAgent, - DoNothingNodeAgent, - DummyAgent, - RandomAgent, -) -from primaite.common.enums import ( - ActionType, - AgentFramework, - AgentIdentifier, - SessionType, -) +from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyAgent, RandomAgent +from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType from primaite.config import lay_down_config, training_config from primaite.config.training_config import TrainingConfig @@ -49,16 +39,12 @@ class PrimaiteSession: if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) self._training_config_path: Final[Union[Path]] = training_config_path - self._training_config: Final[TrainingConfig] = training_config.load( - self._training_config_path - ) + self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) if not isinstance(lay_down_config_path, Path): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path - self._lay_down_config: Dict = lay_down_config.load( - self._lay_down_config_path - ) + self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) self._agent_session: AgentSessionABC = None # noqa self.session_path: Path = None # noqa @@ -69,28 +55,16 
@@ class PrimaiteSession: def setup(self): """Performs the session setup.""" if self._training_config.agent_framework == AgentFramework.CUSTOM: - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}" - ) - if ( - self._training_config.agent_identifier - == AgentIdentifier.HARDCODED - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.HARDCODED}" - ) + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.CUSTOM}") + if self._training_config.agent_identifier == AgentIdentifier.HARDCODED: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}") if self._training_config.action_type == ActionType.NODE: # Deterministic Hardcoded Agent with Node Action Space - self._agent_session = HardCodedNodeAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = HardCodedACLAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -100,24 +74,14 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif ( - self._training_config.agent_identifier - == AgentIdentifier.DO_NOTHING - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.DO_NOTHINGD}" - ) + elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHINGD}") if self._training_config.action_type == ActionType.NODE: - self._agent_session = DoNothingNodeAgent( - 
self._training_config_path, self._lay_down_config_path - ) + self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ACL: # Deterministic Hardcoded Agent with ACL Action Space - self._agent_session = DoNothingACLAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path) elif self._training_config.action_type == ActionType.ANY: # Deterministic Hardcoded Agent with ANY Action Space @@ -127,49 +91,26 @@ class PrimaiteSession: # Invalid AgentIdentifier ActionType combo raise ValueError - elif ( - self._training_config.agent_identifier - == AgentIdentifier.RANDOM - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.RANDOM}" - ) - self._agent_session = RandomAgent( - self._training_config_path, self._lay_down_config_path - ) - elif ( - self._training_config.agent_identifier == AgentIdentifier.DUMMY - ): - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Identifier =" - f" {AgentIdentifier.DUMMY}" - ) - self._agent_session = DummyAgent( - self._training_config_path, self._lay_down_config_path - ) + elif self._training_config.agent_identifier == AgentIdentifier.RANDOM: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}") + self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path) + elif self._training_config.agent_identifier == AgentIdentifier.DUMMY: + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}") + self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path) else: # Invalid AgentFramework AgentIdentifier combo raise ValueError elif self._training_config.agent_framework == AgentFramework.SB3: - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}" - ) + 
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}") # Stable Baselines3 Agent - self._agent_session = SB3Agent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path) elif self._training_config.agent_framework == AgentFramework.RLLIB: - _LOGGER.debug( - f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}" - ) + _LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}") # Ray RLlib Agent - self._agent_session = RLlibAgent( - self._training_config_path, self._lay_down_config_path - ) + self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path) else: # Invalid AgentFramework @@ -182,35 +123,27 @@ class PrimaiteSession: def learn( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Train the agent. - :param time_steps: The number of time steps per episode. - :param episodes: The number of episodes. :param kwargs: Any agent-framework specific key word args. """ if not self._training_config.session_type == SessionType.EVAL: - self._agent_session.learn(time_steps, episodes, **kwargs) + self._agent_session.learn(**kwargs) def evaluate( self, - time_steps: Optional[int] = None, - episodes: Optional[int] = None, **kwargs, ): """ Evaluate the agent. - :param time_steps: The number of time steps per episode. - :param episodes: The number of episodes. :param kwargs: Any agent-framework specific key word args. 
""" if not self._training_config.session_type == SessionType.TRAIN: - self._agent_session.evaluate(time_steps, episodes, **kwargs) + self._agent_session.evaluate(**kwargs) def close(self): """Closes the agent.""" diff --git a/src/primaite/setup/reset_demo_notebooks.py b/src/primaite/setup/reset_demo_notebooks.py index 59eaf8cc..7fa96783 100644 --- a/src/primaite/setup/reset_demo_notebooks.py +++ b/src/primaite/setup/reset_demo_notebooks.py @@ -18,23 +18,17 @@ def run(overwrite_existing: bool = True): :param overwrite_existing: A bool to toggle replacing existing edited notebooks on or off. """ - notebooks_package_data_root = pkg_resources.resource_filename( - "primaite", "notebooks/_package_data" - ) + notebooks_package_data_root = pkg_resources.resource_filename("primaite", "notebooks/_package_data") for subdir, dirs, files in os.walk(notebooks_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath( - fp, notebooks_package_data_root - ).split(os.sep) + path_split = os.path.relpath(fp, notebooks_package_data_root).split(os.sep) target_fp = NOTEBOOKS_DIR / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() if overwrite_existing and not copy_file: - copy_file = (not filecmp.cmp(fp, target_fp)) and ( - ".ipynb_checkpoints" not in str(target_fp) - ) + copy_file = (not filecmp.cmp(fp, target_fp)) and (".ipynb_checkpoints" not in str(target_fp)) if copy_file: shutil.copy2(fp, target_fp) diff --git a/src/primaite/setup/reset_example_configs.py b/src/primaite/setup/reset_example_configs.py index f2b4a18f..5d62298c 100644 --- a/src/primaite/setup/reset_example_configs.py +++ b/src/primaite/setup/reset_example_configs.py @@ -17,16 +17,12 @@ def run(overwrite_existing=True): :param overwrite_existing: A bool to toggle replacing existing edited config on or off. 
""" - configs_package_data_root = pkg_resources.resource_filename( - "primaite", "config/_package_data" - ) + configs_package_data_root = pkg_resources.resource_filename("primaite", "config/_package_data") for subdir, dirs, files in os.walk(configs_package_data_root): for file in files: fp = os.path.join(subdir, file) - path_split = os.path.relpath(fp, configs_package_data_root).split( - os.sep - ) + path_split = os.path.relpath(fp, configs_package_data_root).split(os.sep) target_fp = USERS_CONFIG_DIR / "example_config" / Path(*path_split) target_fp.parent.mkdir(exist_ok=True, parents=True) copy_file = not target_fp.is_file() diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 1a71f0ff..eeafe05e 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -76,12 +76,8 @@ class Transaction(object): row = ( row + _turn_action_space_to_array(self.action_space) - + _turn_obs_space_to_array( - self.obs_space_pre, obs_assets, obs_features - ) - + _turn_obs_space_to_array( - self.obs_space_post, obs_assets, obs_features - ) + + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features) + + _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features) ) return header, row diff --git a/src/primaite/utils/session_output_writer.py b/src/primaite/utils/session_output_writer.py index 86c5ca28..a05b0453 100644 --- a/src/primaite/utils/session_output_writer.py +++ b/src/primaite/utils/session_output_writer.py @@ -51,9 +51,7 @@ class SessionOutputWriter: self._first_write: bool = True def _init_csv_writer(self): - self._csv_file = open( - self._csv_file_path, "w", encoding="UTF8", newline="" - ) + self._csv_file = open(self._csv_file_path, "w", encoding="UTF8", newline="") self._csv_writer = csv.writer(self._csv_file) diff --git a/tests/config/legacy/legacy_training_config.yaml b/tests/config/legacy_conversion/legacy_training_config.yaml similarity index 100% rename 
from tests/config/legacy/legacy_training_config.yaml rename to tests/config/legacy_conversion/legacy_training_config.yaml diff --git a/tests/config/legacy/new_training_config.yaml b/tests/config/legacy_conversion/new_training_config.yaml similarity index 100% rename from tests/config/legacy/new_training_config.yaml rename to tests/config/legacy_conversion/new_training_config.yaml diff --git a/tests/conftest.py b/tests/conftest.py index 41dc5e77..af76b314 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -57,8 +57,6 @@ class TempPrimaiteSession(PrimaiteSession): return self def __exit__(self, type, value, tb): - del self._agent_session._env.episode_av_reward_writer - del self._agent_session._env.transaction_writer shutil.rmtree(self.session_path) shutil.rmtree(self.session_path.parent) _LOGGER.debug(f"Deleted temp session directory: {self.session_path}") @@ -112,9 +110,7 @@ def temp_primaite_session(request): """ training_config_path = request.param[0] lay_down_config_path = request.param[1] - with patch( - "primaite.agents.agent.get_session_path", get_temp_session_path - ) as mck: + with patch("primaite.agents.agent.get_session_path", get_temp_session_path) as mck: mck.session_timestamp = datetime.now() return TempPrimaiteSession(training_config_path, lay_down_config_path) @@ -130,9 +126,7 @@ def temp_session_path() -> Path: session_timestamp = datetime.now() date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = ( - Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path - ) + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) return session_path diff --git a/tests/mock_and_patch/get_session_path_mock.py b/tests/mock_and_patch/get_session_path_mock.py index cfcfb8f0..feff52f6 100644 --- a/tests/mock_and_patch/get_session_path_mock.py +++ b/tests/mock_and_patch/get_session_path_mock.py @@ -16,9 
+16,7 @@ def get_temp_session_path(session_timestamp: datetime) -> Path: """ date_dir = session_timestamp.strftime("%Y-%m-%d") session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - session_path = ( - Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path - ) + session_path = Path(tempfile.gettempdir()) / "primaite" / date_dir / session_path session_path.mkdir(exist_ok=True, parents=True) _LOGGER.debug(f"Created temp session directory: {session_path}") return session_path diff --git a/tests/test_acl.py b/tests/test_acl.py index 260ccffc..30f12697 100644 --- a/tests/test_acl.py +++ b/tests/test_acl.py @@ -95,8 +95,6 @@ def test_rule_hash(): rule = ACLRule("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") hash_value_local = hash(rule) - hash_value_remote = acl.get_dictionary_hash( - "DENY", "192.168.1.1", "192.168.1.2", "TCP", "80" - ) + hash_value_remote = acl.get_dictionary_hash("DENY", "192.168.1.1", "192.168.1.2", "TCP", "80") assert hash_value_local == hash_value_remote diff --git a/tests/test_active_node.py b/tests/test_active_node.py index b6833182..addc595c 100644 --- a/tests/test_active_node.py +++ b/tests/test_active_node.py @@ -60,9 +60,7 @@ def test_os_state_change_if_not_compromised(operating_state, expected_state): 1, ) - active_node.set_software_state_if_not_compromised( - SoftwareState.OVERWHELMED - ) + active_node.set_software_state_if_not_compromised(SoftwareState.OVERWHELMED) assert active_node.software_state == expected_state @@ -100,9 +98,7 @@ def test_file_system_change(operating_state, expected_state): (HardwareState.ON, FileSystemState.CORRUPT), ], ) -def test_file_system_change_if_not_compromised( - operating_state, expected_state -): +def test_file_system_change_if_not_compromised(operating_state, expected_state): """ Test that a node cannot change its file system state. 
@@ -120,8 +116,6 @@ def test_file_system_change_if_not_compromised( 1, ) - active_node.set_file_system_state_if_not_compromised( - FileSystemState.CORRUPT - ) + active_node.set_file_system_state_if_not_compromised(FileSystemState.CORRUPT) assert active_node.file_system_state_actual == expected_state diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 21e4857f..d1082049 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -2,11 +2,7 @@ import numpy as np import pytest -from primaite.environment.observations import ( - NodeLinkTable, - NodeStatuses, - ObservationsHandler, -) +from primaite.environment.observations import NodeLinkTable, NodeStatuses, ObservationsHandler from tests import TEST_CONFIG_ROOT @@ -127,9 +123,7 @@ class TestNodeLinkTable: with temp_primaite_session as session: env = session.env # act = np.asarray([0,]) - obs, reward, done, info = env.step( - 0 - ) # apply the 'do nothing' action + obs, reward, done, info = env.step(0) # apply the 'do nothing' action assert np.array_equal( obs, @@ -192,17 +186,15 @@ class TestNodeStatuses: with temp_primaite_session as session: env = session.env obs, _, _, _ = env.step(0) # apply the 'do nothing' action - assert np.array_equal( - obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0] - ) + print(obs) + assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0]) @pytest.mark.parametrize( "temp_primaite_session", [ [ - TEST_CONFIG_ROOT - / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", + TEST_CONFIG_ROOT / "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml", TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", ] ], diff --git a/tests/test_primaite_session.py b/tests/test_primaite_session.py index 8c8d2b80..ae0b0870 100644 --- a/tests/test_primaite_session.py +++ b/tests/test_primaite_session.py @@ -36,18 +36,12 @@ def test_primaite_session(temp_primaite_session): # Check that both the transactions and av reward csv files exist for file in 
session.learning_path.iterdir(): if file.suffix == ".csv": - assert ( - "all_transactions" in file.name - or "average_reward_per_episode" in file.name - ) + assert "all_transactions" in file.name or "average_reward_per_episode" in file.name # Check that both the transactions and av reward csv files exist for file in session.evaluation_path.iterdir(): if file.suffix == ".csv": - assert ( - "all_transactions" in file.name - or "average_reward_per_episode" in file.name - ) + assert "all_transactions" in file.name or "average_reward_per_episode" in file.name _LOGGER.debug("Inspecting files in temp session path...") for dir_path, dir_names, file_names in os.walk(session_path): diff --git a/tests/test_resetting_node.py b/tests/test_resetting_node.py index e7312777..fb7dc83d 100644 --- a/tests/test_resetting_node.py +++ b/tests/test_resetting_node.py @@ -1,13 +1,7 @@ """Used to test Active Node functions.""" import pytest -from primaite.common.enums import ( - FileSystemState, - HardwareState, - NodeType, - Priority, - SoftwareState, -) +from primaite.common.enums import FileSystemState, HardwareState, NodeType, Priority, SoftwareState from primaite.common.service import Service from primaite.config.training_config import TrainingConfig from primaite.nodes.active_node import ActiveNode @@ -18,9 +12,7 @@ from primaite.nodes.service_node import ServiceNode "starting_operating_state, expected_operating_state", [(HardwareState.RESETTING, HardwareState.ON)], ) -def test_node_resets_correctly( - starting_operating_state, expected_operating_state -): +def test_node_resets_correctly(starting_operating_state, expected_operating_state): """Tests that a node resets correctly.""" active_node = ActiveNode( node_id="0", @@ -59,9 +51,7 @@ def test_node_boots_correctly(operating_state, expected_operating_state): file_system_state="GOOD", config_values=1, ) - service_attributes = Service( - name="node", port="80", software_state=SoftwareState.COMPROMISED - ) + service_attributes = 
Service(name="node", port="80", software_state=SoftwareState.COMPROMISED) service_node.add_service(service_attributes) for x in range(5): diff --git a/tests/test_service_node.py b/tests/test_service_node.py index 9e760b23..4383fc1b 100644 --- a/tests/test_service_node.py +++ b/tests/test_service_node.py @@ -45,9 +45,7 @@ def test_service_state_change(operating_state, expected_state): (HardwareState.ON, SoftwareState.OVERWHELMED), ], ) -def test_service_state_change_if_not_comprised( - operating_state, expected_state -): +def test_service_state_change_if_not_comprised(operating_state, expected_state): """ Test that a node cannot change the state of a running service. @@ -67,8 +65,6 @@ def test_service_state_change_if_not_comprised( service = Service("TCP", 80, SoftwareState.GOOD) service_node.add_service(service) - service_node.set_service_state_if_not_compromised( - "TCP", SoftwareState.OVERWHELMED - ) + service_node.set_service_state_if_not_compromised("TCP", SoftwareState.OVERWHELMED) assert service_node.get_service_state("TCP") == expected_state diff --git a/tests/test_single_action_space.py b/tests/test_single_action_space.py index 1cf63cde..5d55b9c9 100644 --- a/tests/test_single_action_space.py +++ b/tests/test_single_action_space.py @@ -18,7 +18,6 @@ def run_generic_set_actions(env: Primaite): # TEMP - random action for now # action = env.blue_agent_action(obs) action = 0 - print("Episode:", episode, "\nStep:", step) if step == 5: # [1, 1, 2, 1, 1, 1] # Creates an ACL rule @@ -86,8 +85,7 @@ def test_single_action_space_is_valid(temp_primaite_session): "temp_primaite_session", [ [ - TEST_CONFIG_ROOT - / "single_action_space_fixed_blue_actions_main_config.yaml", + TEST_CONFIG_ROOT / "single_action_space_fixed_blue_actions_main_config.yaml", TEST_CONFIG_ROOT / "single_action_space_lay_down_config.yaml", ] ], diff --git a/tests/test_training_config.py b/tests/test_training_config.py index 88bc802b..d7fe4e50 100644 --- a/tests/test_training_config.py +++ 
b/tests/test_training_config.py @@ -7,8 +7,8 @@ from tests import TEST_CONFIG_ROOT def test_legacy_lay_down_config_yaml_conversion(): """Tests the conversion of legacy lay down config files.""" - legacy_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" - new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + legacy_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml" with open(legacy_path, "r") as file: legacy_dict = yaml.safe_load(file) @@ -16,9 +16,7 @@ def test_legacy_lay_down_config_yaml_conversion(): with open(new_path, "r") as file: new_dict = yaml.safe_load(file) - converted_dict = training_config.convert_legacy_training_config_dict( - legacy_dict - ) + converted_dict = training_config.convert_legacy_training_config_dict(legacy_dict) for key, value in new_dict.items(): assert converted_dict[key] == value @@ -26,13 +24,13 @@ def test_legacy_lay_down_config_yaml_conversion(): def test_create_config_values_main_from_file(): """Tests creating an instance of TrainingConfig from file.""" - new_path = TEST_CONFIG_ROOT / "legacy" / "new_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "new_training_config.yaml" training_config.load(new_path) def test_create_config_values_main_from_legacy_file(): """Tests creating an instance of TrainingConfig from legacy file.""" - new_path = TEST_CONFIG_ROOT / "legacy" / "legacy_training_config.yaml" + new_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml" training_config.load(new_path, legacy_file=True) From 16534237e0b3a697edd11f2e8928349328537a5b Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Fri, 30 Jun 2023 17:09:50 +0100 Subject: [PATCH 28/43] #917 - Dropped VerboseLevel in enums.py and changed OutputVerboseLevel to SB3OutputVerboseLevel --- src/primaite/agents/agent.py | 2 +- src/primaite/agents/sb3.py | 2 +- 
src/primaite/common/enums.py | 17 +++-------------- .../training/training_config_main.yaml | 8 ++++---- src/primaite/config/training_config.py | 15 ++++++--------- src/primaite/primaite_session.py | 2 +- 6 files changed, 16 insertions(+), 30 deletions(-) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 90eb2b66..50939210 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -57,7 +57,7 @@ class AgentSessionABC(ABC): lay_down_config_path = Path(lay_down_config_path) self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) - self.output_verbose_level = self._training_config.output_verbose_level + self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level self._env: Primaite self._agent = None diff --git a/src/primaite/agents/sb3.py b/src/primaite/agents/sb3.py index 3161c93a..f5ac44cb 100644 --- a/src/primaite/agents/sb3.py +++ b/src/primaite/agents/sb3.py @@ -56,7 +56,7 @@ class SB3Agent(AgentSessionABC): self._agent = self._agent_class( PPOMlp, self._env, - verbose=self.output_verbose_level, + verbose=self.sb3_output_verbose_level, n_steps=self._training_config.num_steps, tensorboard_log=str(self._tensorboard_log_path), ) diff --git a/src/primaite/common/enums.py b/src/primaite/common/enums.py index a363a1a0..db5d153c 100644 --- a/src/primaite/common/enums.py +++ b/src/primaite/common/enums.py @@ -92,14 +92,6 @@ class SessionType(Enum): "Train then evaluate an agent" -class VerboseLevel(IntEnum): - """PrimAITE Session Output verbose level.""" - - NO_OUTPUT = 0 - INFO = 1 - DEBUG = 2 - - class AgentFramework(Enum): """The agent algorithm framework/package.""" @@ -199,12 +191,9 @@ class LinkStatus(Enum): OVERLOAD = 4 -class OutputVerboseLevel(IntEnum): - """The Agent output verbosity level.""" +class SB3OutputVerboseLevel(IntEnum): + """The Stable Baselines3 learn/eval output verbosity level.""" NONE = 
0 - "No Output" INFO = 1 - "Info Messages" - ALL = 2 - "All Messages" + DEBUG = 2 diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 57793058..a414bed9 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -60,12 +60,12 @@ session_type: TRAIN_EVAL # The high value for the observation space observation_space_high_value: 1000000000 -# The Agent output verbosity level: +# The Stable Baselines3 learn/eval output verbosity level: # Options are: # "NONE" (No Output) -# "INFO" (Info Messages) -# "ALL" (All Messages) -output_verbose_level: NONE +# "INFO" (Info Messages (such as devices and wrappers used)) +# "DEBUG" (All Messages) +sb3_output_verbose_level: NONE # Reward values # Generic diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 3e0f26ca..2ffc2a8c 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -14,7 +14,7 @@ from primaite.common.enums import ( AgentIdentifier, DeepLearningFramework, HardCodedAgentView, - OutputVerboseLevel, + SB3OutputVerboseLevel, SessionType, ) @@ -86,8 +86,8 @@ class TrainingConfig: observation_space_high_value: int = 1000000000 "The high value for the observation space" - output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO - "The Agent output verbosity level" + sb3_output_verbose_level: SB3OutputVerboseLevel = SB3OutputVerboseLevel.NONE + "Stable Baselines3 learn/eval output verbosity level" # Reward values # Generic @@ -189,7 +189,7 @@ class TrainingConfig: "agent_identifier": AgentIdentifier, "action_type": ActionType, "session_type": SessionType, - "output_verbose_level": OutputVerboseLevel, + "sb3_output_verbose_level": SB3OutputVerboseLevel, "hard_coded_agent_view": HardCodedAgentView, } @@ -212,7 +212,7 @@ class 
TrainingConfig: data["deep_learning_framework"] = self.deep_learning_framework.name data["agent_identifier"] = self.agent_identifier.name data["action_type"] = self.action_type.name - data["output_verbose_level"] = self.output_verbose_level.name + data["sb3_output_verbose_level"] = self.sb3_output_verbose_level.name data["session_type"] = self.session_type.name data["hard_coded_agent_view"] = self.hard_coded_agent_view.name @@ -277,7 +277,6 @@ def convert_legacy_training_config_dict( agent_identifier: AgentIdentifier = AgentIdentifier.PPO, action_type: ActionType = ActionType.ANY, num_steps: int = 256, - output_verbose_level: OutputVerboseLevel = OutputVerboseLevel.INFO, ) -> Dict[str, Any]: """ Convert a legacy training config dict to the new format. @@ -291,8 +290,6 @@ def convert_legacy_training_config_dict( don't have action_type values. :param num_steps: The number of steps to set as legacy training configs don't have num_steps values. - :param output_verbose_level: The agent output verbose level to use as - legacy training configs don't have output_verbose_level values. :return: The converted training config dict. 
""" config_dict = { @@ -300,7 +297,7 @@ def convert_legacy_training_config_dict( "agent_identifier": agent_identifier.name, "action_type": action_type.name, "num_steps": num_steps, - "output_verbose_level": output_verbose_level.name, + "sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name, } session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"} legacy_config_dict["sessionType"] = session_type_map[legacy_config_dict["sessionType"]] diff --git a/src/primaite/primaite_session.py b/src/primaite/primaite_session.py index 4ee6c507..df3ebec1 100644 --- a/src/primaite/primaite_session.py +++ b/src/primaite/primaite_session.py @@ -75,7 +75,7 @@ class PrimaiteSession: raise ValueError elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING: - _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHINGD}") + _LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}") if self._training_config.action_type == ActionType.NODE: self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path) From ee94993344d8975336859845ed06364ff80c9a4e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 08:00:51 +0000 Subject: [PATCH 29/43] Apply suggestions from code review --- src/primaite/environment/observations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index b19bd29f..81ddaaf5 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -29,7 +29,7 @@ class AbstractObservationComponent(ABC): self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? 
- self.structure: list[str] + self.structure: List[str] return NotImplemented @abstractmethod From 4299170ce42e68e392f5be2e1ef646c62546c971 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 09:46:52 +0100 Subject: [PATCH 30/43] #1522: added a check for existing links in laydown + test that checks if red agent instructions are random --- src/primaite/environment/primaite_env.py | 6 ++++ tests/test_red_random_agent_behaviour.py | 36 ++++++++++++------------ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index e592e21f..58932c4c 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1265,6 +1265,12 @@ class Primaite(Env): # Bandwidth for all links bandwidths = [i.get_bandwidth() for i in list(self.links.values())] + + if len(bandwidths) < 1: + msg = "Random red agent cannot be used on a network without any links" + _LOGGER.error(msg) + raise Exception(msg) + servers = [node for node in node_list if node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index a86e32c1..c9189c26 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,5 +1,6 @@ -from datetime import time, datetime +from datetime import datetime +from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -24,9 +25,6 @@ def run_generic(env, config_values): if done: break - # Introduce a delay between steps - time.sleep(config_values.time_delay / 1000) - # Reset the environment at the end of the episode env.reset() @@ -40,6 +38,8 @@ def test_random_red_agent_behaviour(): When the initial state is OFF compared to 
reference state which is ON. """ list_of_node_instructions = [] + + # RUN TWICE so we can make sure that red agent is randomised for i in range(2): """Takes a config path and returns the created instance of Primaite.""" @@ -49,7 +49,7 @@ def test_random_red_agent_behaviour(): timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_lay_down_config.yaml", + lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, timestamp_str=timestamp_str, @@ -57,18 +57,18 @@ def test_random_red_agent_behaviour(): training_config = env.training_config training_config.num_steps = env.episode_steps - # TOOD: This needs t be refactored to happen outside. Should be part of - # a main Session class. - if training_config.agent_identifier == "GENERIC": - run_generic(env, training_config) - all_red_actions = env.red_node_pol - list_of_node_instructions.append(all_red_actions) + run_generic(env, training_config) + # add red pol instructions to list + list_of_node_instructions.append(env.red_node_pol) + + # compare instructions to make sure that red instructions are truly random + for index, instruction in enumerate(list_of_node_instructions): + for key in list_of_node_instructions[index].keys(): + instruction: NodeInstructionRed = list_of_node_instructions[index][key] + print(f"run {index}") + print(f"{key} start step: {instruction.get_start_step()}") + print(f"{key} end step: {instruction.get_end_step()}") + print(f"{key} target node id: {instruction.get_target_node_id()}") + print("") - # assert not (list_of_node_instructions[0].__eq__(list_of_node_instructions[1])) - print(list_of_node_instructions[0]["1"].get_start_step()) - print(list_of_node_instructions[0]["1"].get_end_step()) - print(list_of_node_instructions[0]["1"].get_target_node_id()) - 
print(list_of_node_instructions[1]["1"].get_start_step()) - print(list_of_node_instructions[1]["1"].get_end_step()) - print(list_of_node_instructions[1]["1"].get_target_node_id()) assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) From 6c4a538b41988869b5fa9d1f35515d969299e031 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 10:08:25 +0100 Subject: [PATCH 31/43] #1522: run pre-commit --- .gitignore | 2 +- src/primaite/environment/primaite_env.py | 55 ++++++++----------- .../nodes/node_state_instruction_red.py | 2 +- tests/test_red_random_agent_behaviour.py | 7 ++- 4 files changed, 28 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 5adbdc57..b65d1fd8 100644 --- a/.gitignore +++ b/.gitignore @@ -138,4 +138,4 @@ dmypy.json # Cython debug symbols cython_debug/ -.idea/ \ No newline at end of file +.idea/ diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 58932c4c..eb0bc5de 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -3,14 +3,14 @@ import copy import csv import logging +import uuid as uuid from datetime import datetime from pathlib import Path -from random import randint, choice, uniform, sample +from random import choice, randint, sample, uniform from typing import Dict, Tuple, Union import networkx as nx import numpy as np -import uuid as uuid import yaml from gym import Env, spaces from matplotlib import pyplot as plt @@ -60,12 +60,12 @@ class Primaite(Env): ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2 def __init__( - self, - training_config_path: Union[str, Path], - lay_down_config_path: Union[str, Path], - transaction_list, - session_path: Path, - timestamp_str: str, + self, + training_config_path: Union[str, Path], + lay_down_config_path: Union[str, Path], + transaction_list, + session_path: Path, + timestamp_str: str, ): """ The Primaite constructor. 
@@ -448,11 +448,11 @@ class Primaite(Env): elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 6 + len(self.action_dict[_action]) == 6 ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) elif ( - len(self.action_dict[_action]) == 4 + len(self.action_dict[_action]) == 4 ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: @@ -1238,7 +1238,6 @@ class Primaite(Env): def create_random_red_agent(self): """Decide on random red agent for the episode to be called in env.reset().""" - # Reset the current red iers and red node pol self.red_iers = {} self.red_node_pol = {} @@ -1260,7 +1259,7 @@ class Primaite(Env): # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1277,16 +1276,10 @@ class Primaite(Env): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint( - 2, max_step_compromised + 1 - ) # step compromised - pol_service_name = choice( - list(node.services.keys()) - ) + _start_step = randint(2, max_step_compromised + 1) # step compromised + pol_service_name = choice(list(node.services.keys())) - source_node_service = choice( - list(source_node.services.values()) - ) + source_node_service = choice(list(source_node.services.values())) red_pol = NodeStateInstructionRed( _id=_id, @@ -1299,7 +1292,7 @@ class Primaite(Env): _pol_state=SoftwareState.COMPROMISED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[_id] = red_pol @@ -1308,15 +1301,11 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # 
Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint( - _start_step + 2, int(self.episode_steps * 0.8) - ) + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith - ier_load = uniform(0.4, 0.8) * choice( - bandwidths - ) + ier_load = uniform(0.4, 0.8) * choice(bandwidths) ier_protocol = pol_service_name # Same protocol as compromised node ier_service = node.services[pol_service_name] ier_port = ier_service.port @@ -1335,10 +1324,10 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: @@ -1376,6 +1365,6 @@ class Primaite(Env): _pol_state=SoftwareState.OVERWHELMED, _pol_source_node_id=source_node.node_id, _pol_source_node_service=source_node_service.name, - _pol_source_node_service_state=source_node_service.software_state + _pol_source_node_service_state=source_node_service.software_state, ) self.red_node_pol[o_pol_id] = o_red_pol diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 9ae917e9..2f7d0622 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -153,4 +153,4 @@ class NodeStateInstructionRed(object): f"source_node_service={self.source_node_service}, " f"source_node_service_state={self.source_node_service_state}" f")" - ) \ No newline at end of file + ) diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index c9189c26..476a08f1 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ 
-2,6 +2,7 @@ from datetime import datetime from primaite.config.lay_down_config import data_manipulation_config_path from primaite.environment.primaite_env import Primaite +from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT from tests.conftest import _get_temp_session_path @@ -41,14 +42,14 @@ def test_random_red_agent_behaviour(): # RUN TWICE so we can make sure that red agent is randomised for i in range(2): - """Takes a config path and returns the created instance of Primaite.""" session_timestamp: datetime = datetime.now() session_path = _get_temp_session_path(session_timestamp) timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") env = Primaite( - training_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", + training_config_path=TEST_CONFIG_ROOT + / "one_node_states_on_off_main_config.yaml", lay_down_config_path=data_manipulation_config_path(), transaction_list=[], session_path=session_path, @@ -64,7 +65,7 @@ def test_random_red_agent_behaviour(): # compare instructions to make sure that red instructions are truly random for index, instruction in enumerate(list_of_node_instructions): for key in list_of_node_instructions[index].keys(): - instruction: NodeInstructionRed = list_of_node_instructions[index][key] + instruction: NodeStateInstructionRed = list_of_node_instructions[index][key] print(f"run {index}") print(f"{key} start step: {instruction.get_start_step()}") print(f"{key} end step: {instruction.get_end_step()}") From c38c13b82945d9b0343b3ef1ad7b0177435ba38f Mon Sep 17 00:00:00 2001 From: Christopher McCarthy Date: Mon, 3 Jul 2023 10:47:26 +0000 Subject: [PATCH 32/43] Apply suggestions from code review --- src/primaite/agents/agent.py | 2 +- src/primaite/agents/rllib.py | 2 +- src/primaite/environment/observations.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/primaite/agents/agent.py b/src/primaite/agents/agent.py index 
50939210..685fe776 100644 --- a/src/primaite/agents/agent.py +++ b/src/primaite/agents/agent.py @@ -50,7 +50,7 @@ class AgentSessionABC(ABC): def __init__(self, training_config_path, lay_down_config_path): if not isinstance(training_config_path, Path): training_config_path = Path(training_config_path) - self._training_config_path: Final[Union[Path]] = training_config_path + self._training_config_path: Final[Union[Path, str]] = training_config_path self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path) if not isinstance(lay_down_config_path, Path): diff --git a/src/primaite/agents/rllib.py b/src/primaite/agents/rllib.py index 2b6a5a83..d851ba9c 100644 --- a/src/primaite/agents/rllib.py +++ b/src/primaite/agents/rllib.py @@ -55,7 +55,7 @@ class RLlibAgent(AgentSessionABC): msg = "Expected PPO or A2C agent_identifier, " f"got {self._training_config.agent_identifier.value}" _LOGGER.error(msg) raise ValueError(msg) - self._agent_config: PPOConfig + self._agent_config: Union[PPOConfig, A2CConfig] self._current_result: dict self._setup() diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index d0d5d46e..0470828e 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -57,7 +57,7 @@ class NodeLinkTable(AbstractObservationComponent): """ _FIXED_PARAMETERS: int = 4 - _MAX_VAL: int = 1_000_000 + _MAX_VAL: int = 1_000_000_000 _DATA_TYPE: type = np.int64 def __init__(self, env: "Primaite"): From c3ec33e4df44950b9ca2b0b58a2875edaf80dc2d Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 3 Jul 2023 12:03:36 +0100 Subject: [PATCH 33/43] #917 - Added Windows and MacOS to build pipeline. 
Updated so that runs only Python 3.8 and 3.10 (middle version not required) --- .azure/azure-ci-build-pipeline.yaml | 30 ++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 4c15daf5..a93aa131 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -6,16 +6,29 @@ trigger: - bugfix/* - release/* -pool: - vmImage: ubuntu-latest strategy: matrix: - Python38: + UbuntuPython38: python.version: '3.8' - Python39: - python.version: '3.9' - Python310: + imageName: 'ubuntu-latest' + UbuntuPython310: python.version: '3.10' + imageName: 'ubuntu-latest' + WindowsPython38: + python.version: '3.8' + imageName: 'windows-latest' + WindowsPython310: + python.version: '3.10' + imageName: 'windows-latest' + MacOSPython38: + python.version: '3.8' + imageName: 'macOS-latest' + MacOSPython310: + python.version: '3.10' + imageName: 'macOS-latest' + +pool: + vmImage: $(imageName) steps: - task: UsePythonVersion@0 @@ -50,11 +63,6 @@ steps: primaite setup displayName: 'Perform PrimAITE Setup' -#- script: | -# flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics -# flake8 . 
--count --exit-zero --max-complexity=10 --max-line-length=127 --statistics -# displayName: 'Lint with flake8' - - script: | pytest tests/ displayName: 'Run tests' From 0943e9511b55074f6f2721312231c840e2f243cb Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 12:18:58 +0100 Subject: [PATCH 34/43] #1522: refactor red_agent_identifier -> random_red_agent so that it is a boolean + documentation --- docs/source/config.rst | 4 + .../training/training_config_main.yaml | 2 +- .../training_config_random_red_agent.yaml | 2 +- src/primaite/config/training_config.py | 2 +- src/primaite/environment/primaite_env.py | 2 +- tests/config/random_agent_main_config.yaml | 96 ------------------- tests/test_red_random_agent_behaviour.py | 2 + 7 files changed, 10 insertions(+), 100 deletions(-) delete mode 100644 tests/config/random_agent_main_config.yaml diff --git a/docs/source/config.rst b/docs/source/config.rst index 74898ec1..fa58e6cf 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -28,6 +28,10 @@ The environment config file consists of the following attributes: * STABLE_BASELINES3_PPO - Use a SB3 PPO agent * STABLE_BASELINES3_A2C - use a SB3 A2C agent +* **random_red_agent** [bool] + + Determines if the session should be run with a random red agent + * **action_type** [enum] Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session diff --git a/src/primaite/config/_package_data/training/training_config_main.yaml b/src/primaite/config/_package_data/training/training_config_main.yaml index 3fe668e2..8f035d41 100644 --- a/src/primaite/config/_package_data/training/training_config_main.yaml +++ b/src/primaite/config/_package_data/training/training_config_main.yaml @@ -9,7 +9,7 @@ agent_identifier: STABLE_BASELINES3_A2C # RED AGENT IDENTIFIER # RANDOM or NONE -red_agent_identifier: "NONE" +random_red_agent: False # Sets How the Action Space is defined: # "NODE" diff --git 
a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml index 9382a2b5..3e0a3e2f 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -9,7 +9,7 @@ agent_identifier: STABLE_BASELINES3_A2C # RED AGENT IDENTIFIER # RANDOM or NONE -red_agent_identifier: "RANDOM" +random_red_agent: True # Sets How the Action Space is defined: # "NODE" diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 6e88e7cb..7995dfe8 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -21,7 +21,7 @@ class TrainingConfig: agent_identifier: str = "STABLE_BASELINES3_A2C" "The Red Agent algo/class to be used." - red_agent_identifier: str = "RANDOM" + random_red_agent: bool = False "Creates Random Red Agent Attacks" action_type: ActionType = ActionType.ANY diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index eb0bc5de..5cb85afd 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -277,7 +277,7 @@ class Primaite(Env): self.reset_environment() # Create a random red agent to use for this episode - if self.training_config.red_agent_identifier == "RANDOM": + if self.training_config.random_red_agent: self.create_random_red_agent() # Reset counters and totals diff --git a/tests/config/random_agent_main_config.yaml b/tests/config/random_agent_main_config.yaml deleted file mode 100644 index d2d18bbc..00000000 --- a/tests/config/random_agent_main_config.yaml +++ /dev/null @@ -1,96 +0,0 @@ -# Main Config File - -# Generic config values -# Choose one of these (dependent on Agent being trained) -# "STABLE_BASELINES3_PPO" -# "STABLE_BASELINES3_A2C" -# "GENERIC" -agent_identifier: GENERIC -# 
-red_agent_identifier: RANDOM -# Sets How the Action Space is defined: -# "NODE" -# "ACL" -# "ANY" node and acl actions -action_type: ANY -# Number of episodes to run per session -num_episodes: 1 -# Number of time_steps per episode -num_steps: 5 -# Time delay between steps (for generic agents) -time_delay: 1 -# Type of session to be run (TRAINING or EVALUATION) -session_type: TRAINING -# Determine whether to load an agent from file -load_agent: False -# File path and file name of agent if you're loading one in -agent_load_file: C:\[Path]\[agent_saved_filename.zip] - -# Environment config values -# The high value for the observation space -observation_space_high_value: 1_000_000_000 - -# Reward values -# Generic -all_ok: 0 -# Node Hardware State -off_should_be_on: -10 -off_should_be_resetting: -5 -on_should_be_off: -2 -on_should_be_resetting: -5 -resetting_should_be_on: -5 -resetting_should_be_off: -2 -resetting: -3 -# Node Software or Service State -good_should_be_patching: 2 -good_should_be_compromised: 5 -good_should_be_overwhelmed: 5 -patching_should_be_good: -5 -patching_should_be_compromised: 2 -patching_should_be_overwhelmed: 2 -patching: -3 -compromised_should_be_good: -20 -compromised_should_be_patching: -20 -compromised_should_be_overwhelmed: -20 -compromised: -20 -overwhelmed_should_be_good: -20 -overwhelmed_should_be_patching: -20 -overwhelmed_should_be_compromised: -20 -overwhelmed: -20 -# Node File System State -good_should_be_repairing: 2 -good_should_be_restoring: 2 -good_should_be_corrupt: 5 -good_should_be_destroyed: 10 -repairing_should_be_good: -5 -repairing_should_be_restoring: 2 -repairing_should_be_corrupt: 2 -repairing_should_be_destroyed: 0 -repairing: -3 -restoring_should_be_good: -10 -restoring_should_be_repairing: -2 -restoring_should_be_corrupt: 1 -restoring_should_be_destroyed: 2 -restoring: -6 -corrupt_should_be_good: -10 -corrupt_should_be_repairing: -10 -corrupt_should_be_restoring: -10 -corrupt_should_be_destroyed: 2 -corrupt: -10 
-destroyed_should_be_good: -20 -destroyed_should_be_repairing: -20 -destroyed_should_be_restoring: -20 -destroyed_should_be_corrupt: -20 -destroyed: -20 -scanning: -2 -# IER status -red_ier_running: -5 -green_ier_blocked: -10 - -# Patching / Reset durations -os_patching_duration: 5 # The time taken to patch the OS -node_reset_duration: 5 # The time taken to reset a node (hardware) -service_patching_duration: 5 # The time taken to patch a service -file_system_repairing_limit: 5 # The time take to repair the file system -file_system_restoring_limit: 5 # The time take to restore the file system -file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 476a08f1..6b06dbb1 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -55,6 +55,8 @@ def test_random_red_agent_behaviour(): session_path=session_path, timestamp_str=timestamp_str, ) + # set red_agent_ + env.training_config.random_red_agent = True training_config = env.training_config training_config.num_steps = env.episode_steps From cb9d40579f4c5352350e52bfd66b615b35559ae5 Mon Sep 17 00:00:00 2001 From: Czar Echavez Date: Mon, 3 Jul 2023 13:36:14 +0100 Subject: [PATCH 35/43] #1522: create_random_red_agent -> _create_random_red_agent + converting NodeStateInstructionRed into a dataclass --- src/primaite/environment/primaite_env.py | 4 ++-- .../nodes/node_state_instruction_red.py | 20 +++---------------- 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 5cb85afd..823c11fe 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -278,7 +278,7 @@ class Primaite(Env): # Create a random red agent to use for this episode if self.training_config.random_red_agent: - self.create_random_red_agent() + 
self._create_random_red_agent() # Reset counters and totals self.total_reward = 0 @@ -1236,7 +1236,7 @@ class Primaite(Env): combined_action_dict = {**acl_action_dict, **new_node_action_dict} return combined_action_dict - def create_random_red_agent(self): + def _create_random_red_agent(self): """Decide on random red agent for the episode to be called in env.reset().""" # Reset the current red iers and red node pol self.red_iers = {} diff --git a/src/primaite/nodes/node_state_instruction_red.py b/src/primaite/nodes/node_state_instruction_red.py index 2f7d0622..4272ce24 100644 --- a/src/primaite/nodes/node_state_instruction_red.py +++ b/src/primaite/nodes/node_state_instruction_red.py @@ -1,8 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. """Defines node behaviour for Green PoL.""" +from dataclasses import dataclass + from primaite.common.enums import NodePOLType +@dataclass() class NodeStateInstructionRed(object): """The Node State Instruction class.""" @@ -137,20 +140,3 @@ class NodeStateInstructionRed(object): The source node service state """ return self.source_node_service_state - - def __repr__(self): - return ( - f"{self.__class__.__name__}(" - f"id={self.id}, " - f"start_step={self.start_step}, " - f"end_step={self.end_step}, " - f"target_node_id={self.target_node_id}, " - f"initiator={self.initiator}, " - f"pol_type={self.pol_type}, " - f"service_name={self.service_name}, " - f"state={self.state}, " - f"source_node_id={self.source_node_id}, " - f"source_node_service={self.source_node_service}, " - f"source_node_service_state={self.source_node_service_state}" - f")" - ) From 94ca28a85f93aa12514859fd9307ced4eed7a24e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 12:37:08 +0000 Subject: [PATCH 36/43] Add windows build option --- .azure/azure-ci-build-pipeline.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index a93aa131..91c113a7 
100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -55,8 +55,11 @@ steps: displayName: 'Build PrimAITE' - script: | +- ${{ if eq( variables['Agent.OS'], 'Linux') }}: PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) python -m pip install $PRIMAITE_WHEEL[dev] +- ${{ elseif eq( variable['Agent.OS'], 'Windows_NT') }}: + forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file" displayName: 'Install PrimAITE' - script: | From 63a4c1119b453ac27d65ab8764c148bca6d3790a Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 12:40:02 +0000 Subject: [PATCH 37/43] Updated azure-ci-build-pipeline.yaml --- .azure/azure-ci-build-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 91c113a7..5ce13919 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -58,7 +58,7 @@ steps: - ${{ if eq( variables['Agent.OS'], 'Linux') }}: PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) python -m pip install $PRIMAITE_WHEEL[dev] -- ${{ elseif eq( variable['Agent.OS'], 'Windows_NT') }}: +- ${{ elseif eq( variables['Agent.OS'], 'Windows_NT') }}: forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file" displayName: 'Install PrimAITE' From 8101f49a2144fee5074277adb8799f6ff18d4b03 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 12:44:01 +0000 Subject: [PATCH 38/43] Updated azure-ci-build-pipeline.yaml --- .azure/azure-ci-build-pipeline.yaml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 5ce13919..06e44f48 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -55,12 +55,15 @@ steps: displayName: 'Build PrimAITE' - script: | -- ${{ if eq( variables['Agent.OS'], 'Linux') }}: PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) python -m pip install 
$PRIMAITE_WHEEL[dev] -- ${{ elseif eq( variables['Agent.OS'], 'Windows_NT') }}: + displayName: 'Install PrimAITE' + condition: eq( variables['Agent.OS'], 'Linux' ) + +- script: | forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file" displayName: 'Install PrimAITE' + condition: eq( variables['Agent.OS'], 'Windows_NT' ) - script: | primaite setup From f47dd8bf61f4ae1cdf7b55176179d64c6feed176 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 13:36:33 +0000 Subject: [PATCH 39/43] Updated azure-ci-build-pipeline.yaml --- .azure/azure-ci-build-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 06e44f48..244887a1 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -58,7 +58,7 @@ steps: PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl) python -m pip install $PRIMAITE_WHEEL[dev] displayName: 'Install PrimAITE' - condition: eq( variables['Agent.OS'], 'Linux' ) + condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - script: | forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file" From 7ddedfcc57b1d6f4e746d928da9881ecdd8cd333 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 3 Jul 2023 16:02:59 +0000 Subject: [PATCH 40/43] Updated azure-ci-build-pipeline.yaml --- .azure/azure-ci-build-pipeline.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.azure/azure-ci-build-pipeline.yaml b/.azure/azure-ci-build-pipeline.yaml index 244887a1..902eb38d 100644 --- a/.azure/azure-ci-build-pipeline.yaml +++ b/.azure/azure-ci-build-pipeline.yaml @@ -61,7 +61,7 @@ steps: condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' )) - script: | - forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file" + forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]" displayName: 'Install PrimAITE' 
condition: eq( variables['Agent.OS'], 'Windows_NT' ) From 7816e94f832b7dd61a57540bee1f9a0eda96a726 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 3 Jul 2023 17:25:21 +0100 Subject: [PATCH 41/43] #917 - Synced with dev (at the point of random red agent) --- .../training_config_random_red_agent.yaml | 6 +- src/primaite/environment/primaite_env.py | 62 ++++++---- tests/config/test_random_red_main_config.yaml | 112 ++++++++++++++++++ tests/test_red_random_agent_behaviour.py | 76 +++--------- 4 files changed, 172 insertions(+), 84 deletions(-) create mode 100644 tests/config/test_random_red_main_config.yaml diff --git a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml index 3e0a3e2f..96243daf 100644 --- a/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml +++ b/src/primaite/config/_package_data/training/training_config_random_red_agent.yaml @@ -7,8 +7,10 @@ # "GENERIC" agent_identifier: STABLE_BASELINES3_A2C -# RED AGENT IDENTIFIER -# RANDOM or NONE +# Sets whether Red Agent POL and IER is randomised. 
+# Options are: +# True +# False random_red_agent: True # Sets How the Action Space is defined: diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index c80c36ec..36632155 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -427,7 +427,12 @@ class Primaite(Env): for link_key, link_value in self.links.items(): _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - _LOGGER.debug(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) + print( + " Protocol: " + + protocol.get_name().name + + ", Load: " + + str(protocol.get_load()) + ) def interpret_action_and_apply(self, _action): """ @@ -437,16 +442,21 @@ class Primaite(Env): _action: The action space from the agent """ # At the moment, actions are only affecting nodes + if self.training_config.action_type == ActionType.NODE: self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) - elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6 + elif ( + len(self.action_dict[_action]) == 6 + ): # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) - elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 + elif ( + len(self.action_dict[_action]) == 4 + ): # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: - _LOGGER.error("Invalid action type found") + logging.error("Invalid action type found") def apply_actions_to_nodes(self, _action): """ @@ -510,8 +520,7 @@ class Primaite(Env): elif property_action == 1: # Patch (valid action if it's good or compromised) node.set_service_state( - self.services_list[service_index], - SoftwareState.PATCHING, + self.services_list[service_index], SoftwareState.PATCHING ) else: # Node is not of Service Type @@ 
-709,7 +718,8 @@ class Primaite(Env): _LOGGER.error(f"Invalid item_type: {item_type}") pass - _LOGGER.debug("Environment configuration loaded") + _LOGGER.info("Environment configuration loaded") + print("Environment configuration loaded") def create_node(self, item): """ @@ -1166,12 +1176,7 @@ class Primaite(Env): # Use MAX to ensure we get them all for node_action in range(4): for service_state in range(self.num_services): - action = [ - node, - node_property, - node_action, - service_state, - ] + action = [node, node_property, node_action, service_state] # check to see if it's a nothing action (has no effect) if is_valid_node_action(action): actions[action_key] = action @@ -1221,7 +1226,11 @@ class Primaite(Env): # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} + new_node_action_dict = { + k + len(acl_action_dict) - 1: v + for k, v in node_action_dict.items() + if k != 0 + } # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} @@ -1235,7 +1244,8 @@ class Primaite(Env): # Decide how many nodes become compromised node_list = list(self.nodes.values()) - computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + computers = [node for node in node_list if + node.node_type == NodeType.COMPUTER] max_num_nodes_compromised = len( computers ) # only computers can become compromised @@ -1250,7 +1260,7 @@ class Primaite(Env): # For each of the nodes to be compromised decide which step they become compromised max_step_compromised = ( - self.episode_steps // 2 + self.episode_steps // 2 ) # always compromise in first half of episode # Bandwidth for all links @@ -1261,13 +1271,15 @@ class Primaite(Env): _LOGGER.error(msg) raise Exception(msg) - servers = [node for node in node_list if node.node_type == NodeType.SERVER] + servers = [node for 
node in node_list if + node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint(2, max_step_compromised + 1) # step compromised + _start_step = randint(2, + max_step_compromised + 1) # step compromised pol_service_name = choice(list(node.services.keys())) source_node_service = choice(list(source_node.services.values())) @@ -1292,7 +1304,8 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) + ier_start_step = randint(_start_step + 2, + int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith @@ -1315,15 +1328,16 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: # If still none found choose from all servers - possible_ier_destinations = [server.node_id for server in servers] + possible_ier_destinations = [server.node_id for server in + servers] ier_dest = choice(possible_ier_destinations) self.red_iers[ier_id] = IER( ier_id, diff --git a/tests/config/test_random_red_main_config.yaml b/tests/config/test_random_red_main_config.yaml new file mode 100644 index 00000000..800fe808 --- /dev/null +++ b/tests/config/test_random_red_main_config.yaml @@ -0,0 +1,112 @@ +# Training Config File + +# Sets which agent algorithm framework will be used. +# Options are: +# "SB3" (Stable Baselines3) +# "RLLIB" (Ray RLlib) +# "CUSTOM" (Custom Agent) +agent_framework: CUSTOM + +# Sets which Agent class will be used. 
+# Options are: +# "A2C" (Advantage Actor Critic coupled with either SB3 or RLLIB agent_framework) +# "PPO" (Proximal Policy Optimization coupled with either SB3 or RLLIB agent_framework) +# "HARDCODED" (The HardCoded agents coupled with an ACL or NODE action_type) +# "DO_NOTHING" (The DoNothing agents coupled with an ACL or NODE action_type) +# "RANDOM" (primaite.agents.simple.RandomAgent) +# "DUMMY" (primaite.agents.simple.DummyAgent) +agent_identifier: DUMMY + +# Sets whether Red Agent POL and IER is randomised. +# Options are: +# True +# False +random_red_agent: True + +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: NODE +# Number of episodes to run per session +num_episodes: 2 +# Number of time_steps per episode +num_steps: 15 +# Time delay between steps (for generic agents) +time_delay: 1 + +# Type of session to be run (TRAINING or EVALUATION) +session_type: EVAL +# Determine whether to load an agent from file +load_agent: False +# File path and file name of agent if you're loading one in +agent_load_file: C:\[Path]\[agent_saved_filename.zip] + +# Environment config values +# The high value for the observation space +observation_space_high_value: 1000000000 + +# Reward values +# Generic +all_ok: 0 +# Node Hardware State +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 +resetting: -3 +# Node Software or Service State +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 +patching: -3 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 +compromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 +overwhelmed: -20 +# Node File System State 
+good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 +repairing: -3 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 +restoring: -6 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 +corrupt: -10 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 +destroyed: -20 +scanning: -2 +# IER status +red_ier_running: -5 +green_ier_blocked: -10 + +# Patching / Reset durations +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 6b06dbb1..8cf60236 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -1,68 +1,29 @@ -from datetime import datetime +import pytest from primaite.config.lay_down_config import data_manipulation_config_path -from primaite.environment.primaite_env import Primaite from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT -from tests.conftest import _get_temp_session_path - -def run_generic(env, config_values): - """Run against a generic agent.""" - # Reset the environment at the start of the episode - env.reset() - for episode in range(0, 
config_values.num_episodes): - for step in range(0, config_values.num_steps): - # Send the observation space to the agent to get an action - # TEMP - random action for now - # action = env.blue_agent_action(obs) - # action = env.action_space.sample() - action = 0 - - # Run the simulation step on the live environment - obs, reward, done, info = env.step(action) - - # Break if done is True - if done: - break - - # Reset the environment at the end of the episode - env.reset() - - env.close() - - -def test_random_red_agent_behaviour(): - """ - Test that hardware state is penalised at each step. - - When the initial state is OFF compared to reference state which is ON. - """ +@pytest.mark.parametrize( + "temp_primaite_session", + [ + [ + TEST_CONFIG_ROOT / "test_random_red_main_config.yaml", + data_manipulation_config_path(), + ] + ], + indirect=True, +) +def test_random_red_agent_behaviour(temp_primaite_session): + """Test that red agent POL is randomised each episode.""" list_of_node_instructions = [] - # RUN TWICE so we can make sure that red agent is randomised - for i in range(2): - """Takes a config path and returns the created instance of Primaite.""" - session_timestamp: datetime = datetime.now() - session_path = _get_temp_session_path(session_timestamp) + with temp_primaite_session as session: + session.evaluate() + list_of_node_instructions.append(session.env.red_node_pol) - timestamp_str = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S") - env = Primaite( - training_config_path=TEST_CONFIG_ROOT - / "one_node_states_on_off_main_config.yaml", - lay_down_config_path=data_manipulation_config_path(), - transaction_list=[], - session_path=session_path, - timestamp_str=timestamp_str, - ) - # set red_agent_ - env.training_config.random_red_agent = True - training_config = env.training_config - training_config.num_steps = env.episode_steps - - run_generic(env, training_config) - # add red pol instructions to list - list_of_node_instructions.append(env.red_node_pol) + 
session.evaluate() + list_of_node_instructions.append(session.env.red_node_pol) # compare instructions to make sure that red instructions are truly random for index, instruction in enumerate(list_of_node_instructions): @@ -73,5 +34,4 @@ def test_random_red_agent_behaviour(): print(f"{key} end step: {instruction.get_end_step()}") print(f"{key} target node id: {instruction.get_target_node_id()}") print("") - assert list_of_node_instructions[0].__ne__(list_of_node_instructions[1]) From 410d5abe12b4cd47afff86e4f46aaa9a0fa6a68a Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 3 Jul 2023 20:36:21 +0100 Subject: [PATCH 42/43] #917 - Synced with dev and integrated the new observation space --- src/primaite/environment/primaite_env.py | 17 +++- src/primaite/transactions/transaction.py | 57 +++++------- .../transactions/transactions_to_file.py | 91 ------------------- 3 files changed, 36 insertions(+), 129 deletions(-) delete mode 100644 src/primaite/transactions/transactions_to_file.py diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 75bf7310..d7b68045 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -178,6 +178,9 @@ class Primaite(Env): # It will be initialised later. 
self.obs_handler: ObservationsHandler + self._obs_space_description = None + "The env observation space description for transactions writing" + # Open the config file and build the environment laydown with open(self._lay_down_config_path, "r") as file: # Open the config file and build the environment laydown @@ -318,9 +321,16 @@ class Primaite(Env): link.clear_traffic() # Create a Transaction (metric) object for this step - transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) + transaction = Transaction( + self.agent_identifier, + self.actual_episode_count, + self.step_count + ) # Load the initial observation space into the transaction - transaction.set_obs_space(self.obs_handler._flat_observation) + transaction.obs_space = self.obs_handler._flat_observation + + # Set the transaction obs space description + transaction.obs_space_description = self._obs_space_description # Load the action space into the transaction transaction.action_space = copy.deepcopy(action) @@ -675,6 +685,9 @@ class Primaite(Env): """ self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + if not self._obs_space_description: + self._obs_space_description = self.obs_handler.describe_structure() + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 69f0f545..763dc458 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -3,11 +3,18 @@ from datetime import datetime from typing import List, Tuple +from primaite.common.enums import AgentIdentifier + class Transaction(object): """Transaction class.""" - def __init__(self, agent_identifier, episode_number, step_number): + def __init__( + self, + agent_identifier: AgentIdentifier, + episode_number: int, + step_number: int + ): """ Transaction constructor. 
@@ -17,11 +24,14 @@ class Transaction(object): """ self.timestamp = datetime.now() "The datetime of the transaction" - self.agent_identifier = agent_identifier - self.episode_number = episode_number + self.agent_identifier: AgentIdentifier = agent_identifier + "The agent identifier" + self.episode_number: int = episode_number "The episode number" - self.step_number = step_number + self.step_number: int = step_number "The step number" + self.obs_space = None + "The observation space (pre)" self.obs_space_pre = None "The observation space before any actions are taken" self.obs_space_post = None @@ -30,16 +40,8 @@ class Transaction(object): "The reward value" self.action_space = None "The action space invoked by the agent" - - def set_obs_space(self, _obs_space): - """ - Sets the observation space (pre). - - Args: - _obs_space_pre: The observation space before any actions are taken - """ - self.obs_space = _obs_space - + self.obs_space_description = None + "The env observation space description" def as_csv_data(self) -> Tuple[List, List]: """ @@ -51,32 +53,16 @@ class Transaction(object): action_length = self.action_space else: action_length = self.action_space.size - obs_shape = self.obs_space_post.shape - obs_assets = self.obs_space_post.shape[0] - if len(obs_shape) == 1: - # A bit of a workaround but I think the way transactions are - # written will change soon - obs_features = 1 - else: - obs_features = self.obs_space_post.shape[1] # Create the action space headers array action_header = [] for x in range(action_length): action_header.append("AS_" + str(x)) - # Create the observation space headers array - obs_header_initial = [] - obs_header_new = [] - for x in range(obs_assets): - for y in range(obs_features): - obs_header_initial.append("OSI_" + str(x) + "_" + str(y)) - obs_header_new.append("OSN_" + str(x) + "_" + str(y)) - # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_header_initial + 
obs_header_new - + header = header + action_header + self.obs_space_description + row = [ str(self.timestamp), str(self.episode_number), @@ -84,10 +70,9 @@ class Transaction(object): str(self.reward), ] row = ( - row - + _turn_action_space_to_array(self.action_space) - + _turn_obs_space_to_array(self.obs_space_pre, obs_assets, obs_features) - + _turn_obs_space_to_array(self.obs_space_post, obs_assets, obs_features) + row + + _turn_action_space_to_array(self.action_space) + + self.obs_space.tolist() ) return header, row diff --git a/src/primaite/transactions/transactions_to_file.py b/src/primaite/transactions/transactions_to_file.py deleted file mode 100644 index 4e364f0b..00000000 --- a/src/primaite/transactions/transactions_to_file.py +++ /dev/null @@ -1,91 +0,0 @@ -# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. -"""Writes the Transaction log list out to file for evaluation to utilse.""" - -import csv -from pathlib import Path - -from primaite import getLogger - -_LOGGER = getLogger(__name__) - - -def turn_action_space_to_array(_action_space): - """ - Turns action space into a string array so it can be saved to csv. - - Args: - _action_space: The action space. - """ - if isinstance(_action_space, list): - return [str(i) for i in _action_space] - else: - return [str(_action_space)] - - -def write_transaction_to_file( - transaction_list, - session_path: Path, - timestamp_str: str, - obs_space_description: list, -): - """ - Writes transaction logs to file to support training evaluation. - - :param transaction_list: The list of transactions from all steps and all - episodes. - :param session_path: The directory path the session is writing to. - :param timestamp_str: The session timestamp in the format: - _. 
- """ - # Get the first transaction and use it to determine the makeup of the - # observation space and action space - # Label the obs space fields in csv as "OSI_1_1", "OSN_1_1" and action - # space as "AS_1" - # This will be tied into the PrimAITE Use Case so that they make sense - template_transation = transaction_list[0] - action_length = template_transation.action_space.size - # obs_shape = template_transation.obs_space_post.shape - # obs_assets = template_transation.obs_space_post.shape[0] - # if len(obs_shape) == 1: - # bit of a workaround but I think the way transactions are written will change soon - # obs_features = 1 - # else: - # obs_features = template_transation.obs_space_post.shape[1] - - # Create the action space headers array - action_header = [] - for x in range(action_length): - action_header.append("AS_" + str(x)) - - # Create the observation space headers array - # obs_header_initial = [f"pre_{o}" for o in obs_space_description] - # obs_header_new = [f"post_{o}" for o in obs_space_description] - - # Open up a csv file - header = ["Timestamp", "Episode", "Step", "Reward"] - header = header + action_header + obs_space_description - - try: - filename = session_path / f"all_transactions_{timestamp_str}.csv" - _LOGGER.debug(f"Saving transaction logs: {filename}") - csv_file = open(filename, "w", encoding="UTF8", newline="") - csv_writer = csv.writer(csv_file) - csv_writer.writerow(header) - - for transaction in transaction_list: - csv_data = [ - str(transaction.timestamp), - str(transaction.episode_number), - str(transaction.step_number), - str(transaction.reward), - ] - csv_data = ( - csv_data - + turn_action_space_to_array(transaction.action_space) - + transaction.obs_space.tolist() - ) - csv_writer.writerow(csv_data) - - csv_file.close() - except Exception: - _LOGGER.error("Could not save the transaction file", exc_info=True) From 34b294f89a5224207d3154d79232e525a49731e0 Mon Sep 17 00:00:00 2001 From: Chris McCarthy Date: Mon, 3 Jul 2023 20:40:38 
+0100 Subject: [PATCH 43/43] #917 - Reinstalled the pre-commit hook --- src/primaite/environment/observations.py | 4 +- src/primaite/environment/primaite_env.py | 71 +++++++----------------- src/primaite/transactions/transaction.py | 15 +---- tests/test_red_random_agent_behaviour.py | 1 + 4 files changed, 24 insertions(+), 67 deletions(-) diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 511fb008..23bc4a39 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -266,9 +266,7 @@ class NodeStatuses(AbstractObservationComponent): for service in services: structure.append(f"node_{node_id}_service_{service}_state_NONE") for state in SoftwareState: - structure.append( - f"node_{node_id}_service_{service}_state_{state.name}" - ) + structure.append(f"node_{node_id}_service_{service}_state_{state.name}") return structure diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index d7b68045..03c23f93 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -1,14 +1,11 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. 
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class.""" import copy -import csv import logging import uuid as uuid -from datetime import datetime from pathlib import Path -from typing import Dict, Final, Tuple, Union from random import choice, randint, sample, uniform -from typing import Dict, Tuple, Union +from typing import Dict, Final, Tuple, Union import networkx as nx import numpy as np @@ -321,11 +318,7 @@ class Primaite(Env): link.clear_traffic() # Create a Transaction (metric) object for this step - transaction = Transaction( - self.agent_identifier, - self.actual_episode_count, - self.step_count - ) + transaction = Transaction(self.agent_identifier, self.actual_episode_count, self.step_count) # Load the initial observation space into the transaction transaction.obs_space = self.obs_handler._flat_observation @@ -436,12 +429,7 @@ class Primaite(Env): for link_key, link_value in self.links.items(): _LOGGER.debug("Link ID: " + link_value.get_id()) for protocol in link_value.protocol_list: - print( - " Protocol: " - + protocol.get_name().name - + ", Load: " - + str(protocol.get_load()) - ) + print(" Protocol: " + protocol.get_name().name + ", Load: " + str(protocol.get_load())) def interpret_action_and_apply(self, _action): """ @@ -456,13 +444,9 @@ class Primaite(Env): self.apply_actions_to_nodes(_action) elif self.training_config.action_type == ActionType.ACL: self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 6 - ): # ACL actions in multidiscrete form have len 6 + elif len(self.action_dict[_action]) == 6: # ACL actions in multidiscrete form have len 6 self.apply_actions_to_acl(_action) - elif ( - len(self.action_dict[_action]) == 4 - ): # Node actions in multdiscrete (array) from have len 4 + elif len(self.action_dict[_action]) == 4: # Node actions in multdiscrete (array) from have len 4 self.apply_actions_to_nodes(_action) else: logging.error("Invalid action type found") @@ -528,9 +512,7 @@ 
class Primaite(Env): return elif property_action == 1: # Patch (valid action if it's good or compromised) - node.set_service_state( - self.services_list[service_index], SoftwareState.PATCHING - ) + node.set_service_state(self.services_list[service_index], SoftwareState.PATCHING) else: # Node is not of Service Type return @@ -1238,11 +1220,7 @@ class Primaite(Env): # Change node keys to not overlap with acl keys # Only 1 nothing action (key 0) is required, remove the other - new_node_action_dict = { - k + len(acl_action_dict) - 1: v - for k, v in node_action_dict.items() - if k != 0 - } + new_node_action_dict = {k + len(acl_action_dict) - 1: v for k, v in node_action_dict.items() if k != 0} # Combine the Node dict and ACL dict combined_action_dict = {**acl_action_dict, **new_node_action_dict} @@ -1256,11 +1234,8 @@ class Primaite(Env): # Decide how many nodes become compromised node_list = list(self.nodes.values()) - computers = [node for node in node_list if - node.node_type == NodeType.COMPUTER] - max_num_nodes_compromised = len( - computers - ) # only computers can become compromised + computers = [node for node in node_list if node.node_type == NodeType.COMPUTER] + max_num_nodes_compromised = len(computers) # only computers can become compromised # random select between 1 and max_num_nodes_compromised num_nodes_to_compromise = randint(1, max_num_nodes_compromised) @@ -1271,9 +1246,7 @@ class Primaite(Env): source_node = choice(nodes_to_be_compromised) # For each of the nodes to be compromised decide which step they become compromised - max_step_compromised = ( - self.episode_steps // 2 - ) # always compromise in first half of episode + max_step_compromised = self.episode_steps // 2 # always compromise in first half of episode # Bandwidth for all links bandwidths = [i.get_bandwidth() for i in list(self.links.values())] @@ -1283,15 +1256,13 @@ class Primaite(Env): _LOGGER.error(msg) raise Exception(msg) - servers = [node for node in node_list if - node.node_type 
== NodeType.SERVER] + servers = [node for node in node_list if node.node_type == NodeType.SERVER] for n, node in enumerate(nodes_to_be_compromised): # 1: Use Node PoL to set node to compromised _id = str(uuid.uuid4()) - _start_step = randint(2, - max_step_compromised + 1) # step compromised + _start_step = randint(2, max_step_compromised + 1) # step compromised pol_service_name = choice(list(node.services.keys())) source_node_service = choice(list(source_node.services.values())) @@ -1316,8 +1287,7 @@ class Primaite(Env): ier_id = str(uuid.uuid4()) # Launch the attack after node is compromised, and not right at the end of the episode - ier_start_step = randint(_start_step + 2, - int(self.episode_steps * 0.8)) + ier_start_step = randint(_start_step + 2, int(self.episode_steps * 0.8)) ier_end_step = self.episode_steps # Randomise the load, as a percentage of a random link bandwith @@ -1325,9 +1295,7 @@ class Primaite(Env): ier_protocol = pol_service_name # Same protocol as compromised node ier_service = node.services[pol_service_name] ier_port = ier_service.port - ier_mission_criticality = ( - 0 # Red IER will never be important to green agent success - ) + ier_mission_criticality = 0 # Red IER will never be important to green agent success # We choose a node to attack based on the first that applies: # a. Green IERs, select dest node of the red ier based on dest node of green IER # b. 
Attack a random server that doesn't have a DENY acl rule in default config @@ -1340,16 +1308,15 @@ class Primaite(Env): if len(possible_ier_destinations) < 1: for server in servers: if not self.acl.is_blocked( - node.ip_address, - server.ip_address, - ier_service, - ier_port, + node.ip_address, + server.ip_address, + ier_service, + ier_port, ): possible_ier_destinations.append(server.node_id) if len(possible_ier_destinations) < 1: # If still none found choose from all servers - possible_ier_destinations = [server.node_id for server in - servers] + possible_ier_destinations = [server.node_id for server in servers] ier_dest = choice(possible_ier_destinations) self.red_iers[ier_id] = IER( ier_id, diff --git a/src/primaite/transactions/transaction.py b/src/primaite/transactions/transaction.py index 763dc458..7db2444a 100644 --- a/src/primaite/transactions/transaction.py +++ b/src/primaite/transactions/transaction.py @@ -9,12 +9,7 @@ from primaite.common.enums import AgentIdentifier class Transaction(object): """Transaction class.""" - def __init__( - self, - agent_identifier: AgentIdentifier, - episode_number: int, - step_number: int - ): + def __init__(self, agent_identifier: AgentIdentifier, episode_number: int, step_number: int): """ Transaction constructor. 
@@ -62,18 +57,14 @@ class Transaction(object): # Open up a csv file header = ["Timestamp", "Episode", "Step", "Reward"] header = header + action_header + self.obs_space_description - + row = [ str(self.timestamp), str(self.episode_number), str(self.step_number), str(self.reward), ] - row = ( - row - + _turn_action_space_to_array(self.action_space) - + self.obs_space.tolist() - ) + row = row + _turn_action_space_to_array(self.action_space) + self.obs_space.tolist() return header, row diff --git a/tests/test_red_random_agent_behaviour.py b/tests/test_red_random_agent_behaviour.py index 8cf60236..f8885f3e 100644 --- a/tests/test_red_random_agent_behaviour.py +++ b/tests/test_red_random_agent_behaviour.py @@ -4,6 +4,7 @@ from primaite.config.lay_down_config import data_manipulation_config_path from primaite.nodes.node_state_instruction_red import NodeStateInstructionRed from tests import TEST_CONFIG_ROOT + @pytest.mark.parametrize( "temp_primaite_session", [