temp commit

This commit is contained in:
Chris McCarthy
2023-06-13 09:42:54 +01:00
parent dc0349c37b
commit 40686031e6
11 changed files with 626 additions and 173 deletions

View File

@@ -22,46 +22,64 @@ The environment config file consists of the following attributes:
* **agent_identifier** [enum]
This identifies the agent to use for the session. Select from one of the following:
* GENERIC - Where a user-developed agent is to be used
* STABLE_BASELINES3_PPO - Use a SB3 PPO agent
* STABLE_BASELINES3_A2C - Use a SB3 A2C agent
* **agent_framework** [enum]
This identifies the agent framework to be used to instantiate the agent algorithm. Select from one of the following:
* NONE - Where a user developed agent is to be used
* SB3 - Stable Baselines3
* RLLIB - Ray RLlib.
* **red_agent_identifier** [enum]
This identifies the red agent to use for the session. Select from one of the following:
* A2C - Advantage Actor Critic
* PPO - Proximal Policy Optimization
* HARDCODED - A custom-built deterministic agent
* RANDOM - A stochastic random agent
* **action_type** [enum]
Determines whether a NODE, ACL, or ANY (combined NODE & ACL) action space format is adopted for the session
* **num_episodes** [int]
This defines the number of episodes that the agent will train or be evaluated over.
* **num_steps** [int]
Determines the number of steps to run in each episode of the session
* **time_delay** [int]
The time delay (in milliseconds) to take between each step when running a GENERIC agent session
* **session_type** [text]
Type of session to be run (TRAINING, EVALUATION, or BOTH)
* **load_agent** [bool]
Determine whether to load an agent from file
* **agent_load_file** [text]
File path and file name of agent if you're loading one in
* **observation_space_high_value** [int]
The high value to use for values in the observation space. This is set to 1000000000 by default, and should not need changing in most cases
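Taken together, the generic attributes above map onto a training config YAML like the following. This is a minimal illustrative sketch; the values shown are example choices, not defaults, and the file path is hypothetical.

```yaml
# Illustrative fragment only -- attribute names come from the list above.
agent_framework: SB3
agent_identifier: STABLE_BASELINES3_PPO
red_agent_identifier: RANDOM
action_type: NODE
num_episodes: 10
num_steps: 256
time_delay: 10                  # only used by GENERIC agent sessions
session_type: TRAINING
load_agent: false
agent_load_file: path/to/agent.zip   # hypothetical path
observation_space_high_value: 1000000000
```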
**Reward-Based Config Values**
@@ -69,95 +87,95 @@ Rewards are calculated based on the difference between the current state and ref
* **Generic [all_ok]** [int]
The score to give when the current situation (for a given component) is no different from that expected in the baseline (i.e. as though no blue or red agent actions had been undertaken)
* **Node Hardware State [off_should_be_on]** [int]
The score to give when the node should be on, but is off
* **Node Hardware State [off_should_be_resetting]** [int]
The score to give when the node should be resetting, but is off
* **Node Hardware State [on_should_be_off]** [int]
The score to give when the node should be off, but is on
* **Node Hardware State [on_should_be_resetting]** [int]
The score to give when the node should be resetting, but is on
* **Node Hardware State [resetting_should_be_on]** [int]
The score to give when the node should be on, but is resetting
* **Node Hardware State [resetting_should_be_off]** [int]
The score to give when the node should be off, but is resetting
* **Node Hardware State [resetting]** [int]
The score to give when the node is resetting
* **Node Operating System or Service State [good_should_be_patching]** [int]
The score to give when the state should be patching, but is good
* **Node Operating System or Service State [good_should_be_compromised]** [int]
The score to give when the state should be compromised, but is good
* **Node Operating System or Service State [good_should_be_overwhelmed]** [int]
The score to give when the state should be overwhelmed, but is good
* **Node Operating System or Service State [patching_should_be_good]** [int]
The score to give when the state should be good, but is patching
* **Node Operating System or Service State [patching_should_be_compromised]** [int]
The score to give when the state should be compromised, but is patching
* **Node Operating System or Service State [patching_should_be_overwhelmed]** [int]
The score to give when the state should be overwhelmed, but is patching
* **Node Operating System or Service State [patching]** [int]
The score to give when the state is patching
* **Node Operating System or Service State [compromised_should_be_good]** [int]
The score to give when the state should be good, but is compromised
* **Node Operating System or Service State [compromised_should_be_patching]** [int]
The score to give when the state should be patching, but is compromised
* **Node Operating System or Service State [compromised_should_be_overwhelmed]** [int]
The score to give when the state should be overwhelmed, but is compromised
* **Node Operating System or Service State [compromised]** [int]
The score to give when the state is compromised
* **Node Operating System or Service State [overwhelmed_should_be_good]** [int]
The score to give when the state should be good, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_patching]** [int]
The score to give when the state should be patching, but is overwhelmed
* **Node Operating System or Service State [overwhelmed_should_be_compromised]** [int]
The score to give when the state should be compromised, but is overwhelmed
* **Node Operating System or Service State [overwhelmed]** [int]
The score to give when the state is overwhelmed
* **Node File System State [good_should_be_repairing]** [int]
@@ -261,37 +279,37 @@ Rewards are calculated based on the difference between the current state and ref
* **IER Status [red_ier_running]** [int]
The score to give when a red agent IER is permitted to run
* **IER Status [green_ier_blocked]** [int]
The score to give when a green agent IER is prevented from running
**Patching / Reset Durations**
* **os_patching_duration** [int]
The number of steps to take when patching an Operating System
* **node_reset_duration** [int]
The number of steps to take when resetting a node's hardware state
* **service_patching_duration** [int]
The number of steps to take when patching a service
* **file_system_repairing_limit** [int]
The number of steps to take when repairing the file system
* **file_system_restoring_limit** [int]
The number of steps to take when restoring the file system
* **file_system_scanning_limit** [int]
The number of steps to take when scanning the file system
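The reward keys above follow a single naming pattern: ``<actual>_should_be_<expected>``, grouped by component. As a minimal sketch of how such per-component scores could be looked up, consider the following; the dictionary contents and score values here are illustrative stand-ins, not PrimAITE's actual implementation or defaults.

```python
# Illustrative reward lookup keyed by (actual, expected) state pairs.
# Score values are made up; real values come from the training config file.
HARDWARE_STATE_SCORES = {
    ("OFF", "ON"): -10,              # off_should_be_on
    ("OFF", "RESETTING"): -5,        # off_should_be_resetting
    ("ON", "OFF"): -2,               # on_should_be_off
    ("RESETTING", "RESETTING"): -1,  # resetting (expected transition, small cost)
}
ALL_OK = 0  # the generic all_ok score


def hardware_state_score(actual: str, expected: str) -> int:
    """Return the configured score for a state mismatch, or all_ok."""
    if actual == expected and actual != "RESETTING":
        return ALL_OK
    return HARDWARE_STATE_SCORES.get((actual, expected), ALL_OK)


print(hardware_state_score("ON", "ON"))   # matches baseline -> 0
print(hardware_state_score("OFF", "ON"))  # off_should_be_on -> -10
```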
The Lay Down Config
*******************
@@ -300,22 +318,22 @@ The lay down config file consists of the following attributes:
* **item_type: ACTIONS** [enum]
Determines whether a NODE or ACL action space format is adopted for the session
* **item_type: OBSERVATION_SPACE** [dict]
Allows the user to configure the observation space by combining one or more observation components. The list of
available components is in :py:mod:`primaite.environment.observations`.
The observation space config item should have a ``components`` key which is a list of components. Each component
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
component while it is being initialised.
This example illustrates the correct format for the observation space config item
.. code-block:: yaml

- item_type: OBSERVATION_SPACE
components:
- name: LINK_TRAFFIC_LEVELS
options:
@@ -328,15 +346,15 @@ The lay down config file consists of the following attributes:
* **item_type: PORTS** [int]
Provides a list of ports modelled in this session
* **item_type: SERVICES** [freetext]
Provides a list of services modelled in this session
* **item_type: NODE**
Defines a node included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
@@ -355,7 +373,7 @@ The lay down config file consists of the following attributes:
* **item_type: LINK**
Defines a link included in the system laydown being simulated. It should consist of the following attributes:
* **id** [int]: Unique ID for this YAML item
* **name** [freetext]: Human-readable name of the component
@@ -365,7 +383,7 @@ The lay down config file consists of the following attributes:
* **item_type: GREEN_IER**
Defines a green agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin
@@ -379,7 +397,7 @@ The lay down config file consists of the following attributes:
* **item_type: RED_IER**
Defines a red agent Information Exchange Requirement (IER). It should consist of:
* **id** [int]: Unique ID for this YAML item
* **start_step** [int]: The start step (in the episode) for this IER to begin

View File

@@ -32,6 +32,7 @@ dependencies = [
"numpy==1.23.5",
"platformdirs==3.5.1",
"PyYAML==6.0",
"ray[rllib]==2.2.0",
"stable-baselines3==1.6.2",
"typer[all]==0.9.0"
]

View File

@@ -0,0 +1,36 @@
from abc import ABC, abstractmethod
from typing import Optional
from primaite.environment.primaite_env import Primaite
class AgentABC(ABC):
@abstractmethod
def __init__(self, env: Primaite):
self._env: Primaite = env
self._agent = None
@abstractmethod
def _setup(self):
pass
@abstractmethod
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
pass
@abstractmethod
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
pass
@abstractmethod
def load(self):
pass
@abstractmethod
def save(self):
pass
@abstractmethod
def export(self):
pass
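A concrete agent would fill in the abstract hooks above. The sketch below only illustrates the subclassing shape: it carries a trimmed copy of the interface so it runs standalone (rather than importing Primaite), and ``DoNothingAgent`` is a hypothetical example, not part of the codebase.

```python
from abc import ABC, abstractmethod
from typing import Optional


class AgentABC(ABC):
    """Trimmed stand-in for the interface above, so this sketch runs standalone."""

    @abstractmethod
    def __init__(self, env):
        self._env = env
        self._agent = None

    @abstractmethod
    def learn(self, time_steps: Optional[int], episodes: Optional[int]):
        pass


class DoNothingAgent(AgentABC):
    """Hypothetical agent that satisfies the interface without training."""

    def __init__(self, env):
        super().__init__(env)

    def learn(self, time_steps: Optional[int] = None,
              episodes: Optional[int] = None):
        return 0.0  # no training performed


agent = DoNothingAgent(env=None)  # a real agent would receive a Primaite env
print(agent.learn())
```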

View File

@@ -0,0 +1,177 @@
import glob
import time
from enum import Enum
from pathlib import Path
from typing import Union, Optional
from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from primaite.config import training_config
from primaite.environment.primaite_env import Primaite
class DLFramework(Enum):
"""The DL Frameworks enumeration."""
TF = "tf"
TF2 = "tf2"
TORCH = "torch"
def env_creator(env_config):
training_config_path = env_config["training_config_path"]
lay_down_config_path = env_config["lay_down_config_path"]
return Primaite(training_config_path, lay_down_config_path, [])
def get_ppo_config(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
framework: Optional[DLFramework] = DLFramework.TORCH
) -> PPOConfig:
# Register environment
register_env("primaite", env_creator)
# Setup PPO
config = PPOConfig()
config_values = training_config.load(training_config_path)
# Setup our config object to use our environment
config.environment(
env="primaite",
env_config=dict(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path
)
)
env_config = config_values
action_type = env_config.action_type
red_agent = env_config.red_agent_identifier
if red_agent == "RANDOM" and action_type == "NODE":
config.training(
train_batch_size=6000, lr=5e-5
) # number of steps in a training iteration
elif red_agent == "RANDOM" and action_type != "NODE":
config.training(train_batch_size=6000, lr=5e-5)
elif red_agent == "CONFIG" and action_type == "NODE":
config.training(train_batch_size=400, lr=5e-5)
elif red_agent == "CONFIG" and action_type != "NODE":
config.training(train_batch_size=500, lr=5e-5)
else:
config.training(train_batch_size=500, lr=5e-5)
# Decide if you want torch or tensorflow DL framework. Default is "tf"
config.framework(framework=framework.value)
# Set the log level to DEBUG, INFO, WARN, or ERROR
config.debugging(seed=415, log_level="ERROR")
# Setup evaluation
# Explicitly set "explore"=False to override default
# config.evaluation(
# evaluation_interval=100,
# evaluation_duration=20,
# # evaluation_duration_unit="timesteps",) #default episodes
# evaluation_config={"explore": False},
# )
# Setup sampling rollout workers
config.rollouts(
num_rollout_workers=4,  # number of parallel rollout workers
num_envs_per_worker=1,
horizon=128,  # max number of steps in an episode
)
config.build() # Build config
return config
def train(
num_iterations: int,
config: Optional[PPOConfig] = None,
algo: Optional[Algorithm] = None
):
"""
Requires either the algorithm config (new model) or the algorithm itself (continue training from checkpoint)
"""
start_time = time.time()
if algo is None:
algo = config.build()
elif config is None:
config = algo.get_config()
print(f"Algorithm type: {type(algo)}")
# iterations are not the same as episodes.
for i in range(num_iterations):
result = algo.train()
# # Save every 10 iterations or after last iteration in training
# if (i % 100 == 0) or (i == num_iterations - 1):
print(
f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}")
# save checkpoint file
checkpoint_file = algo.save("./")
print(f"Checkpoint saved at {checkpoint_file}")
# convert num_iterations to num_episodes
num_episodes = len(
result["hist_stats"]["episode_lengths"]) * num_iterations
# convert num_iterations to num_timesteps
num_timesteps = sum(
result["hist_stats"]["episode_lengths"]) * num_iterations
# calculate number of wins
# train time
print(f"Training took {time.time() - start_time:.2f} seconds")
print(
f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}")
return result
def load_model_from_checkpoint(config, checkpoint=None):
# create an empty Algorithm
algo = config.build()
if checkpoint is None:
# Get the checkpoint with the highest iteration number
checkpoint = get_most_recent_checkpoint(config)
# restore the agent from the checkpoint
algo.restore(checkpoint)
return algo
def get_most_recent_checkpoint(config):
"""
Get the most recent checkpoint for specified action type, red agent and algorithm
"""
env_config = list(config.env_config.values())[0]
action_type = env_config.action_type
red_agent = env_config.red_agent_identifier
algo_name = config.algo_class.__name__
# Gets the latest checkpoint (highest iteration not datetime) to use as the final trained model
relevant_checkpoints = glob.glob(
f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*"
)
checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints]
max_checkpoint = str(max(checkpoint_numbers))
checkpoint_number_to_use = "0" * (6 - len(max_checkpoint)) + max_checkpoint
checkpoint = (
relevant_checkpoints[0].split("_")[0]
+ "_"
+ checkpoint_number_to_use
+ "/rllib_checkpoint.json"
)
return checkpoint
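The checkpoint selection above relies on RLlib's ``checkpoint_<NNNNNN>`` directory naming. The padding-and-max logic can be isolated and tested on its own; the directory names and helper below are illustrative, not taken from the codebase.

```python
def latest_checkpoint_dir(checkpoint_dirs: list[str]) -> str:
    """Pick the checkpoint directory with the highest iteration number.

    Assumes RLlib-style names ending in ``checkpoint_<NNNNNN>``.
    """
    numbers = [int(d.rsplit("_", 1)[1]) for d in checkpoint_dirs]
    max_number = max(numbers)
    prefix = checkpoint_dirs[0].rsplit("_", 1)[0]
    # zfill reproduces the fixed six-digit padding used in the directory names
    return f"{prefix}_{str(max_number).zfill(6)}"


dirs = ["out/checkpoint_000010", "out/checkpoint_000002"]
print(latest_checkpoint_dir(dirs))  # out/checkpoint_000010
```

Using ``rsplit`` on the last underscore (rather than ``split("_")[1]`` as above) keeps the lookup correct even when the enclosing path itself contains underscores.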

View File

@@ -0,0 +1,28 @@
# from typing import Optional
#
# from primaite.agents.agent_abc import AgentABC
# from primaite.environment.primaite_env import Primaite
#
#
# class SB3PPO(AgentABC):
# def __init__(self, env: Primaite):
# super().__init__(env)
#
# def _setup(self):
# if self._env.training_config:
# pass
#
# def learn(self, time_steps: Optional[int], episodes: Optional[int]):
# pass
#
# def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
# pass
#
# def load(self):
# pass
#
# def save(self):
# pass
#
# def export(self):
# pass

View File

@@ -79,6 +79,33 @@ class Protocol(Enum):
NONE = 7
class SessionType(Enum):
"""The type of PrimAITE Session to be run."""
TRAINING = 1
EVALUATION = 2
BOTH = 3
class VerboseLevel(Enum):
"""PrimAITE Session Output verbose level."""
NO_OUTPUT = 0
INFO = 1
DEBUG = 2
class AgentFramework(Enum):
"""The agent framework enumeration."""
NONE = 0
SB3 = 1
RLLIB = 2
class RedAgentIdentifier(Enum):
"""The red agent identifier enumeration."""
A2C = 1
PPO = 2
HARDCODED = 3
RANDOM = 4
class ActionType(Enum):
"""Action type enumeration."""

View File

@@ -1,95 +0,0 @@
# # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
# """The config class."""
# from dataclasses import dataclass
#
# from primaite.common.enums import ActionType
#
#
# @dataclass()
# class TrainingConfig:
# """Class to hold main config values."""
#
# # Generic
# agent_identifier: str # The Red Agent algo/class to be used
# action_type: ActionType # type of action to use (NODE/ACL/ANY)
# num_episodes: int # number of episodes to train over
# num_steps: int # number of steps in an episode
# time_delay: int # delay between steps (ms) - applies to generic agents only
# # file
# session_type: str # the session type to run (TRAINING or EVALUATION)
# load_agent: str # Determine whether to load an agent from file
# agent_load_file: str # File path and file name of agent if you're loading one in
#
# # Environment
# observation_space_high_value: int # The high value for the observation space
#
# # Reward values
# # Generic
# all_ok: int
# # Node Hardware State
# off_should_be_on: int
# off_should_be_resetting: int
# on_should_be_off: int
# on_should_be_resetting: int
# resetting_should_be_on: int
# resetting_should_be_off: int
# resetting: int
# # Node Software or Service State
# good_should_be_patching: int
# good_should_be_compromised: int
# good_should_be_overwhelmed: int
# patching_should_be_good: int
# patching_should_be_compromised: int
# patching_should_be_overwhelmed: int
# patching: int
# compromised_should_be_good: int
# compromised_should_be_patching: int
# compromised_should_be_overwhelmed: int
# compromised: int
# overwhelmed_should_be_good: int
# overwhelmed_should_be_patching: int
# overwhelmed_should_be_compromised: int
# overwhelmed: int
# # Node File System State
# good_should_be_repairing: int
# good_should_be_restoring: int
# good_should_be_corrupt: int
# good_should_be_destroyed: int
# repairing_should_be_good: int
# repairing_should_be_restoring: int
# repairing_should_be_corrupt: int
# repairing_should_be_destroyed: int # Repairing does not fix destroyed state - you need to restore
#
# repairing: int
# restoring_should_be_good: int
# restoring_should_be_repairing: int
# restoring_should_be_corrupt: int # Not the optimal method (as repair will fix corruption)
#
# restoring_should_be_destroyed: int
# restoring: int
# corrupt_should_be_good: int
# corrupt_should_be_repairing: int
# corrupt_should_be_restoring: int
# corrupt_should_be_destroyed: int
# corrupt: int
# destroyed_should_be_good: int
# destroyed_should_be_repairing: int
# destroyed_should_be_restoring: int
# destroyed_should_be_corrupt: int
# destroyed: int
# scanning: int
# # IER status
# red_ier_running: int
# green_ier_blocked: int
#
# # Patching / Reset
# os_patching_duration: int # The time taken to patch the OS
# node_reset_duration: int # The time taken to reset a node (hardware)
# node_booting_duration = 0 # The Time taken to turn on the node
# node_shutdown_duration = 0 # The time taken to turn off the node
# service_patching_duration: int # The time taken to patch a service
# file_system_repairing_limit: int # The time take to repair a file
# file_system_restoring_limit: int # The time take to restore a file
# file_system_scanning_limit: int # The time taken to scan the file system
# # Patching / Reset
#

View File

@@ -2,6 +2,8 @@
from pathlib import Path
from typing import Final
import networkx
from primaite import USERS_CONFIG_DIR, getLogger
_LOGGER = getLogger(__name__)
@@ -9,6 +11,12 @@ _LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
# class LayDownConfig:
# network: networkx.Graph
# POL
# EIR
# ACL
def ddos_basic_one_config_path() -> Path:
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.

View File

@@ -1,4 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Union, Optional
@@ -6,7 +8,8 @@ from typing import Any, Dict, Final, Union, Optional
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import ActionType
from primaite.common.enums import ActionType, RedAgentIdentifier, \
AgentFramework, SessionType
_LOGGER = getLogger(__name__)
@@ -16,10 +19,11 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
@dataclass()
class TrainingConfig:
"""The Training Config class."""
agent_framework: AgentFramework = AgentFramework.SB3
"The agent framework."
# Generic
agent_identifier: str = "STABLE_BASELINES3_A2C"
"The Red Agent algo/class to be used."
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
"The red agent/algo class."
action_type: ActionType = ActionType.ANY
"The ActionType to use."
@@ -38,8 +42,8 @@ class TrainingConfig:
"The delay between steps (ms). Applies to generic agents only."
# file
session_type: SessionType = SessionType.TRAINING
"The type of PrimAITE session to run."
load_agent: bool = False
"Determine whether to load an agent from file."
@@ -137,6 +141,24 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system."
@classmethod
def from_dict(
cls,
config_dict: Dict[str, Union[str, int, bool]]
) -> TrainingConfig:
"""Create a ``TrainingConfig`` from a dict, converting enum fields by name."""
field_enum_map = {
"agent_framework": AgentFramework,
"red_agent_identifier": RedAgentIdentifier,
"action_type": ActionType,
"session_type": SessionType
}
for field_name, enum_class in field_enum_map.items():
if field_name in config_dict:
# Look up the enum member by the string value held in the dict
config_dict[field_name] = enum_class[config_dict[field_name]]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True):
"""
Serialise the ``TrainingConfig`` as dict.
@@ -196,10 +218,8 @@ def load(file_path: Union[str, Path],
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
try:
return TrainingConfig.from_dict(config)
except TypeError as e:
msg = (
f"Error when creating an instance of {TrainingConfig} "
@@ -214,22 +234,30 @@ def load(file_path: Union[str, Path],
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
agent_framework: AgentFramework = AgentFramework.SB3,
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param red_agent_identifier: The red agent identifier to use as legacy
training configs don't have red_agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:return: The converted training config dict.
"""
config_dict = {
"agent_framework": agent_framework.name,
"red_agent_identifier": red_agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps
}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
@@ -246,7 +274,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
:return: The mapped key.
"""
key_mapping = {
"agentIdentifier": None,
"numEpisodes": "num_episodes",
"timeDelay": "time_delay",
"configFilename": None,
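The enum conversion performed by ``from_dict`` reduces to looking enum members up by name. A self-contained sketch of that pattern (using a toy enum that mirrors the real ``SessionType`` members, not PrimAITE's own classes):

```python
from enum import Enum


class SessionType(Enum):
    # Toy stand-in mirroring the real enum's members
    TRAINING = 1
    EVALUATION = 2
    BOTH = 3


def convert_enum_fields(config: dict, field_enum_map: dict) -> dict:
    """Replace string values with enum members, looked up by name."""
    for field_name, enum_class in field_enum_map.items():
        if field_name in config:
            # Look up the *value* from the dict, not the field name itself
            config[field_name] = enum_class[config[field_name]]
    return config


cfg = convert_enum_fields({"session_type": "TRAINING"},
                          {"session_type": SessionType})
print(cfg["session_type"])  # SessionType.TRAINING
```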

View File

@@ -0,0 +1,216 @@
from __future__ import annotations
import json
from datetime import datetime
from pathlib import Path
from typing import Final, Optional, Union
from uuid import uuid4
from primaite import getLogger, SESSIONS_DIR
from primaite.config.training_config import TrainingConfig
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _get_session_path(session_timestamp: datetime) -> Path:
"""
Get the directory path the session will output to.
This is set in the format of:
~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = SESSIONS_DIR / date_dir / session_dir
session_path.mkdir(exist_ok=True, parents=True)
_LOGGER.debug(f"Created PrimAITE Session path: {session_path}")
return session_path
class PrimaiteSession:
"""Manages a PrimAITE session: environment setup, learning, and outputs."""
def __init__(
self,
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
auto: bool = True
):
if not isinstance(training_config_path, Path):
training_config_path = Path(training_config_path)
self._training_config_path: Final[Union[Path]] = training_config_path
if not isinstance(lay_down_config_path, Path):
lay_down_config_path = Path(lay_down_config_path)
self._lay_down_config_path: Final[Union[Path]] = lay_down_config_path
self._auto: Final[bool] = auto
self._uuid: str = str(uuid4())
self._session_timestamp: Final[datetime] = datetime.now()
self._session_path: Final[Path] = _get_session_path(
self._session_timestamp
)
self._timestamp_str: Final[str] = self._session_timestamp.strftime(
"%Y-%m-%d_%H-%M-%S")
self._metadata_path = self._session_path / "session_metadata.json"
self._env = None
self._training_config = None
self._can_learn: bool = False
_LOGGER.debug("")
if self._auto:
self.setup()
self.learn()
@property
def uuid(self):
"""The session UUID."""
return self._uuid
def _setup_primaite_env(self, transaction_list: Optional[list] = None):
if not transaction_list:
transaction_list = []
self._env: Primaite = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=transaction_list,
session_path=self._session_path,
timestamp_str=self._timestamp_str
)
self._training_config: TrainingConfig = self._env.training_config
def _write_session_metadata_file(self):
"""
Write the ``session_metadata.json`` file.
Creates a ``session_metadata.json`` in the ``session_dir`` directory
and adds the following key/value pairs:
- uuid: The UUID assigned to the session upon instantiation.
- start_datetime: The date & time the session started in iso format.
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
- env:
- training_config:
- All training config items
- lay_down_config:
- All lay down config items
"""
metadata_dict = {
"uuid": self._uuid,
"start_datetime": self._session_timestamp.isoformat(),
"end_datetime": None,
"total_episodes": None,
"total_time_steps": None,
"env": {
"training_config": self._env.training_config.to_dict(
json_serializable=True
),
"lay_down_config": self._env.lay_down_config,
},
}
_LOGGER.debug(f"Writing Session Metadata file: {self._metadata_path}")
with open(self._metadata_path, "w") as file:
json.dump(metadata_dict, file)
def _update_session_metadata_file(self):
"""
Update the ``session_metadata.json`` file.
Updates the `session_metadata.json`` in the ``session_dir`` directory
with the following key/value pairs:
- end_datetime: NULL.
- total_episodes: NULL.
- total_time_steps: NULL.
"""
with open(self._metadata_path, "r") as file:
metadata_dict = json.load(file)
metadata_dict["end_datetime"] = datetime.now().isoformat()
metadata_dict["total_episodes"] = self._env.episode_count
metadata_dict["total_time_steps"] = self._env.total_step_count
_LOGGER.debug(f"Updating Session Metadata file: {self._metadata_path}")
with open(self._metadata_path, "w") as file:
json.dump(metadata_dict, file)
def setup(self):
"""Set up the PrimAITE environment ready for learning."""
self._setup_primaite_env()
self._can_learn = True
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
iterations: Optional[int] = None,
**kwargs
):
"""Run the environment against the configured agent."""
if self._can_learn:
# Run environment against an agent
if self._training_config.agent_identifier == "GENERIC":
run_generic(env=self._env, config_values=self._training_config)
elif self._training_config.agent_identifier == "STABLE_BASELINES3_PPO":
run_stable_baselines3_ppo(
env=self._env,
config_values=self._training_config,
session_path=self._session_path,
timestamp_str=self._timestamp_str,
)
elif self._training_config.agent_identifier == "STABLE_BASELINES3_A2C":
run_stable_baselines3_a2c(
env=self._env,
config_values=self._training_config,
session_path=self._session_path,
timestamp_str=self._timestamp_str,
)
print("Session finished")
_LOGGER.debug("Session finished")
print("Saving transaction logs...")
write_transaction_to_file(
transaction_list=self._env.transaction_list,  # assumes the env exposes its transaction list
session_path=self._session_path,
timestamp_str=self._timestamp_str,
)
print("Updating Session Metadata file...")
self._update_session_metadata_file()
print("Finished")
_LOGGER.debug("Finished")
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
**kwargs
):
pass
def export(self):
pass
@classmethod
def import_agent(
cls,
agent_path: str,
training_config_path: str,
lay_down_config_path: str
) -> PrimaiteSession:
session = PrimaiteSession(training_config_path, lay_down_config_path)
# Reset the UUID
session._uuid = ""
return session
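The ``~/primaite/sessions/<yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss>`` layout built by ``_get_session_path`` can be reproduced standalone; this sketch substitutes a temporary directory for ``SESSIONS_DIR`` and a fixed timestamp so the result is deterministic.

```python
import tempfile
from datetime import datetime
from pathlib import Path


def make_session_path(root: Path, ts: datetime) -> Path:
    """Mirror the <yyyy-mm-dd>/<yyyy-mm-dd>_<hh-mm-ss> session layout."""
    date_dir = ts.strftime("%Y-%m-%d")
    session_dir = ts.strftime("%Y-%m-%d_%H-%M-%S")
    path = root / date_dir / session_dir
    path.mkdir(parents=True, exist_ok=True)
    return path


root = Path(tempfile.mkdtemp())
p = make_session_path(root, datetime(2023, 6, 13, 9, 42, 54))
print(p.relative_to(root).as_posix())  # 2023-06-13/2023-06-13_09-42-54
```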

View File

@@ -1,11 +1,20 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets which agent algorithm framework will be used:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray[RLlib])
# "NONE" (Custom Agent)
agent_framework: RLLIB
# Sets which Red Agent algo/class will be used:
# "PPO" (Proximal Policy Optimization)
# "A2C" (Advantage Actor Critic)
# "HARDCODED" (Custom Agent)
# "RANDOM" (Random Action)
red_agent_identifier: PPO
# Sets How the Action Space is defined:
# "NODE"
# "ACL"