temp commit

This commit is contained in:
Chris McCarthy
2023-06-13 09:42:54 +01:00
parent 9b0e24c27b
commit eb3368edd6
11 changed files with 626 additions and 173 deletions

View File

@@ -2,6 +2,8 @@
from pathlib import Path
from typing import Final
import networkx
from primaite import USERS_CONFIG_DIR, getLogger
_LOGGER = getLogger(__name__)
@@ -9,6 +11,12 @@ _LOGGER = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
# class LayDownConfig:
# network: networkx.Graph
# POL
# EIR
# ACL
def ddos_basic_one_config_path() -> Path:
"""
The path to the example lay_down_config_1_DDOS_basic.yaml file.

View File

@@ -1,4 +1,6 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
from __future__ import annotations
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, Final, Union, Optional
@@ -6,7 +8,8 @@ from typing import Any, Dict, Final, Union, Optional
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import ActionType
from primaite.common.enums import ActionType, RedAgentIdentifier, \
AgentFramework, SessionType
_LOGGER = getLogger(__name__)
@@ -16,10 +19,11 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
@dataclass()
class TrainingConfig:
"""The Training Config class."""
agent_framework: AgentFramework = AgentFramework.SB3
"The agent framework."
# Generic
agent_identifier: str = "STABLE_BASELINES3_A2C"
"The Red Agent algo/class to be used."
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
"The red agent/algo class."
action_type: ActionType = ActionType.ANY
"The ActionType to use."
@@ -38,8 +42,8 @@ class TrainingConfig:
"The delay between steps (ms). Applies to generic agents only."
# file
session_type: str = "TRAINING"
"the session type to run (TRAINING or EVALUATION)"
session_type: SessionType = SessionType.TRAINING
"The type of PrimAITE session to run."
load_agent: str = False
"Determine whether to load an agent from file."
@@ -137,6 +141,24 @@ class TrainingConfig:
file_system_scanning_limit: int = 5
"The time taken to scan the file system."
@classmethod
def from_dict(
cls,
config_dict: Dict[str, Union[str, int, bool]]
) -> TrainingConfig:
field_enum_map = {
"agent_framework": AgentFramework,
"red_agent_identifier": RedAgentIdentifier,
"action_type": ActionType,
"session_type": SessionType
}
for field, enum_class in field_enum_map.items():
if field in config_dict:
config_dict[field] = enum_class[field]
return TrainingConfig(**config_dict)
def to_dict(self, json_serializable: bool = True):
"""
Serialise the ``TrainingConfig`` as dict.
@@ -196,10 +218,8 @@ def load(file_path: Union[str, Path],
f"from legacy format. Attempting to use file as is."
)
_LOGGER.error(msg)
# Convert values to Enums
config["action_type"] = ActionType[config["action_type"]]
try:
return TrainingConfig(**config)
return TrainingConfig.from_dict(**config)
except TypeError as e:
msg = (
f"Error when creating an instance of {TrainingConfig} "
@@ -214,22 +234,30 @@ def load(file_path: Union[str, Path],
def convert_legacy_training_config_dict(
legacy_config_dict: Dict[str, Any],
num_steps: int = 256,
action_type: str = "ANY"
agent_framework: AgentFramework = AgentFramework.SB3,
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_steps: int = 256
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
:param legacy_config_dict: A legacy training config dict.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:param agent_framework: The agent framework to use as legacy training
configs don't have agent_framework values.
:param red_agent_identifier: The red agent identifier to use as legacy
training configs don't have red_agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_steps: The number of steps to set as legacy training configs
don't have num_steps values.
:return: The converted training config dict.
"""
config_dict = {
"num_steps": num_steps,
"action_type": action_type
"agent_framework": agent_framework.name,
"red_agent_identifier": red_agent_identifier.name,
"action_type": action_type.name,
"num_steps": num_steps
}
for legacy_key, value in legacy_config_dict.items():
new_key = _get_new_key_from_legacy(legacy_key)
@@ -246,7 +274,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
:return: The mapped key.
"""
key_mapping = {
"agentIdentifier": "agent_identifier",
"agentIdentifier": None,
"numEpisodes": "num_episodes",
"timeDelay": "time_delay",
"configFilename": None,