temp commit
This commit is contained in:
@@ -2,6 +2,8 @@
|
||||
from pathlib import Path
|
||||
from typing import Final
|
||||
|
||||
import networkx
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
@@ -9,6 +11,12 @@ _LOGGER = getLogger(__name__)
|
||||
_EXAMPLE_LAY_DOWN: Final[Path] = USERS_CONFIG_DIR / "example_config" / "lay_down"
|
||||
|
||||
|
||||
# class LayDownConfig:
|
||||
# network: networkx.Graph
|
||||
# POL
|
||||
# EIR
|
||||
# ACL
|
||||
|
||||
def ddos_basic_one_config_path() -> Path:
|
||||
"""
|
||||
The path to the example lay_down_config_1_DDOS_basic.yaml file.
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union, Optional
|
||||
@@ -6,7 +8,8 @@ from typing import Any, Dict, Final, Union, Optional
|
||||
import yaml
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite.common.enums import ActionType
|
||||
from primaite.common.enums import ActionType, RedAgentIdentifier, \
|
||||
AgentFramework, SessionType
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -16,10 +19,11 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
|
||||
@dataclass()
|
||||
class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
agent_framework: AgentFramework = AgentFramework.SB3
|
||||
"The agent framework."
|
||||
|
||||
# Generic
|
||||
agent_identifier: str = "STABLE_BASELINES3_A2C"
|
||||
"The Red Agent algo/class to be used."
|
||||
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
|
||||
"The red agent/algo class."
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use."
|
||||
@@ -38,8 +42,8 @@ class TrainingConfig:
|
||||
"The delay between steps (ms). Applies to generic agents only."
|
||||
|
||||
# file
|
||||
session_type: str = "TRAINING"
|
||||
"the session type to run (TRAINING or EVALUATION)"
|
||||
session_type: SessionType = SessionType.TRAINING
|
||||
"The type of PrimAITE session to run."
|
||||
|
||||
load_agent: str = False
|
||||
"Determine whether to load an agent from file."
|
||||
@@ -137,6 +141,24 @@ class TrainingConfig:
|
||||
file_system_scanning_limit: int = 5
|
||||
"The time taken to scan the file system."
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
config_dict: Dict[str, Union[str, int, bool]]
|
||||
) -> TrainingConfig:
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"red_agent_identifier": RedAgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType
|
||||
}
|
||||
|
||||
for field, enum_class in field_enum_map.items():
|
||||
if field in config_dict:
|
||||
config_dict[field] = enum_class[field]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
def to_dict(self, json_serializable: bool = True):
|
||||
"""
|
||||
Serialise the ``TrainingConfig`` as dict.
|
||||
@@ -196,10 +218,8 @@ def load(file_path: Union[str, Path],
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
# Convert values to Enums
|
||||
config["action_type"] = ActionType[config["action_type"]]
|
||||
try:
|
||||
return TrainingConfig(**config)
|
||||
return TrainingConfig.from_dict(**config)
|
||||
except TypeError as e:
|
||||
msg = (
|
||||
f"Error when creating an instance of {TrainingConfig} "
|
||||
@@ -214,22 +234,30 @@ def load(file_path: Union[str, Path],
|
||||
|
||||
def convert_legacy_training_config_dict(
|
||||
legacy_config_dict: Dict[str, Any],
|
||||
num_steps: int = 256,
|
||||
action_type: str = "ANY"
|
||||
agent_framework: AgentFramework = AgentFramework.SB3,
|
||||
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO,
|
||||
action_type: ActionType = ActionType.ANY,
|
||||
num_steps: int = 256
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Convert a legacy training config dict to the new format.
|
||||
|
||||
:param legacy_config_dict: A legacy training config dict.
|
||||
:param num_steps: The number of steps to set as legacy training configs
|
||||
don't have num_steps values.
|
||||
:param agent_framework: The agent framework to use as legacy training
|
||||
configs don't have agent_framework values.
|
||||
:param red_agent_identifier: The red agent identifier to use as legacy
|
||||
training configs don't have red_agent_identifier values.
|
||||
:param action_type: The action space type to set as legacy training configs
|
||||
don't have action_type values.
|
||||
:param num_steps: The number of steps to set as legacy training configs
|
||||
don't have num_steps values.
|
||||
:return: The converted training config dict.
|
||||
"""
|
||||
config_dict = {
|
||||
"num_steps": num_steps,
|
||||
"action_type": action_type
|
||||
"agent_framework": agent_framework.name,
|
||||
"red_agent_identifier": red_agent_identifier.name,
|
||||
"action_type": action_type.name,
|
||||
"num_steps": num_steps
|
||||
}
|
||||
for legacy_key, value in legacy_config_dict.items():
|
||||
new_key = _get_new_key_from_legacy(legacy_key)
|
||||
@@ -246,7 +274,7 @@ def _get_new_key_from_legacy(legacy_key: str) -> str:
|
||||
:return: The mapped key.
|
||||
"""
|
||||
key_mapping = {
|
||||
"agentIdentifier": "agent_identifier",
|
||||
"agentIdentifier": None,
|
||||
"numEpisodes": "num_episodes",
|
||||
"timeDelay": "time_delay",
|
||||
"configFilename": None,
|
||||
|
||||
Reference in New Issue
Block a user