Merge branch 'dev' into feature/1623-typehints
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import logging
|
||||
import logging.config
|
||||
import sys
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Access Control List. Models firewall functionality."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements the access control list implementation for the network."""
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A class that implements an access control list rule."""
|
||||
|
||||
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Common interface between RL agents from different libraries and PrimAITE."""
|
||||
|
||||
@@ -1,28 +1,24 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, TYPE_CHECKING, Union
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
import yaml
|
||||
|
||||
import primaite
|
||||
from primaite import getLogger, SESSIONS_DIR
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.data_viz.session_plots import plot_av_reward_per_episode
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from logging import Logger
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
_LOGGER: "Logger" = getLogger(__name__)
|
||||
|
||||
|
||||
@@ -53,38 +49,63 @@ class AgentSessionABC(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an agent session from config files.
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
|
||||
If training configuration and laydown configuration are provided with a session path,
|
||||
the session path will be used.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# initialise variables
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
self.is_eval = False
|
||||
|
||||
self._uuid = str(uuid4())
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
# convert session to path
|
||||
if session_path is not None:
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
# load session
|
||||
self.load(session_path)
|
||||
else:
|
||||
# set training config path
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
self._uuid = str(uuid4())
|
||||
"The session timestamp"
|
||||
self.session_path = get_session_path(self.session_timestamp)
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def timestamp_str(self) -> str:
|
||||
@@ -233,51 +254,27 @@ class AgentSessionABC(ABC):
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls, path: Union[str, Path]) -> AgentSessionABC:
|
||||
def load(self, path: Union[str, Path]):
|
||||
"""Load an agent from file."""
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
md_dict, training_config_path, laydown_config_path = parse_session_metadata(path)
|
||||
|
||||
if path.exists():
|
||||
# Unpack the session_metadata.json file
|
||||
md_file = path / "session_metadata.json"
|
||||
with open(md_file, "r") as file:
|
||||
md_dict = json.load(file)
|
||||
# set training config path
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = laydown_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# Create a temp directory and dump the training and lay down
|
||||
# configs into it
|
||||
temp_dir = path / ".temp"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
# set random UUID for session
|
||||
self._uuid = md_dict["uuid"]
|
||||
|
||||
temp_tc = temp_dir / "tc.yaml"
|
||||
with open(temp_tc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["training_config"], file)
|
||||
|
||||
temp_ldc = temp_dir / "ldc.yaml"
|
||||
with open(temp_ldc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["lay_down_config"], file)
|
||||
|
||||
agent = cls(temp_tc, temp_ldc)
|
||||
|
||||
agent.session_path = path
|
||||
|
||||
return agent
|
||||
|
||||
else:
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
# set the session path
|
||||
self.session_path = path
|
||||
"The Session path"
|
||||
|
||||
@property
|
||||
def _saved_agent_path(self) -> Path:
|
||||
file_name = (
|
||||
f"{self._training_config.agent_framework}_"
|
||||
f"{self._training_config.agent_identifier}_"
|
||||
f"{self.timestamp_str}.zip"
|
||||
)
|
||||
file_name = f"{self._training_config.agent_framework}_" f"{self._training_config.agent_identifier}" f".zip"
|
||||
return self.learning_path / file_name
|
||||
|
||||
@abstractmethod
|
||||
@@ -313,104 +310,3 @@ class AgentSessionABC(ABC):
|
||||
fig = plot_av_reward_per_episode(path, title, subtitle)
|
||||
fig.write_image(image_path)
|
||||
_LOGGER.debug(f"Saved average rewards per episode plot to: {path}")
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self) -> None:
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self) -> None:
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs: np.ndarray) -> None:
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls) -> None:
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self) -> None:
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
116
src/primaite/agents/hardcoded_abc.py
Normal file
116
src/primaite/agents/hardcoded_abc.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import time
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HardCodedAgentSessionABC(AgentSessionABC):
|
||||
"""
|
||||
An Agent Session ABC for evaluation deterministic agents.
|
||||
|
||||
This class cannot be directly instantiated and must be inherited from with all implemented abstract methods
|
||||
implemented.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
):
|
||||
"""
|
||||
Initialise a hardcoded agent session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
self._setup()
|
||||
|
||||
def _setup(self):
|
||||
self._env: Primaite = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
super()._setup()
|
||||
self._can_learn = False
|
||||
self._can_evaluate = True
|
||||
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Train the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
|
||||
@abstractmethod
|
||||
def _calculate_action(self, obs):
|
||||
pass
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Evaluate the agent.
|
||||
|
||||
:param kwargs: Any agent-specific key-word args to be passed.
|
||||
"""
|
||||
self._env.set_as_eval() # noqa
|
||||
self.is_eval = True
|
||||
|
||||
time_steps = self._training_config.num_eval_steps
|
||||
episodes = self._training_config.num_eval_episodes
|
||||
|
||||
obs = self._env.reset()
|
||||
for episode in range(episodes):
|
||||
# Reset env and collect initial observation
|
||||
for step in range(time_steps):
|
||||
# Calculate action
|
||||
action = self._calculate_action(obs)
|
||||
|
||||
# Perform the step
|
||||
obs, reward, done, info = self._env.step(action)
|
||||
|
||||
if done:
|
||||
break
|
||||
|
||||
# Introduce a delay between steps
|
||||
time.sleep(self._training_config.time_delay / 1000)
|
||||
obs = self._env.reset()
|
||||
self._env.close()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path=None):
|
||||
"""Load an agent from file."""
|
||||
_LOGGER.warning("Deterministic agents cannot be loaded")
|
||||
|
||||
def save(self):
|
||||
"""Save the agent."""
|
||||
_LOGGER.warning("Deterministic agents cannot be saved")
|
||||
|
||||
def export(self):
|
||||
"""Export the agent to transportable file format."""
|
||||
_LOGGER.warning("Deterministic agents cannot be exported")
|
||||
@@ -1,10 +1,11 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from primaite.acl.access_control_list import AccessControlList
|
||||
from primaite.acl.acl_rule import ACLRule
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import (
|
||||
get_new_action,
|
||||
get_node_of_ip,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import numpy as np
|
||||
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_node_enum, transform_change_obs_readable
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import shutil
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from ray.rllib.algorithms import Algorithm
|
||||
@@ -14,7 +15,7 @@ from ray.tune.logger import UnifiedLogger
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -48,7 +49,12 @@ def _custom_log_creator(session_path: Path) -> Callable[[Dict], UnifiedLogger]:
|
||||
class RLlibAgent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Ray RLlib agent."""
|
||||
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the RLLib Agent training session.
|
||||
|
||||
@@ -61,6 +67,13 @@ class RLlibAgent(AgentSessionABC):
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
# TODO: implement RLlib agent loading
|
||||
if session_path is not None:
|
||||
msg = "RLlib agent loading has not been implemented yet"
|
||||
_LOGGER.error(msg)
|
||||
print(msg)
|
||||
raise NotImplementedError
|
||||
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
msg = f"Expected RLLIB agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
|
||||
@@ -1,14 +1,16 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import numpy as np
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.common.enums import AgentFramework, AgentIdentifier
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
@@ -21,7 +23,12 @@ _LOGGER: "Logger" = getLogger(__name__)
|
||||
class SB3Agent(AgentSessionABC):
|
||||
"""An AgentSession class that implements a Stable Baselines3 agent."""
|
||||
|
||||
def __init__(self, training_config_path: Union[str, Path], lay_down_config_path: Union[str, Path]) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
|
||||
@@ -34,7 +41,7 @@ class SB3Agent(AgentSessionABC):
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
@@ -51,7 +58,7 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self._tensorboard_log_path = self.learning_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
self._setup()
|
||||
|
||||
_LOGGER.debug(
|
||||
f"Created {self.__class__.__name__} using: "
|
||||
f"agent_framework={self._training_config.agent_framework}, "
|
||||
@@ -61,8 +68,10 @@ class SB3Agent(AgentSessionABC):
|
||||
|
||||
self.is_eval = False
|
||||
|
||||
self._setup()
|
||||
|
||||
def _setup(self) -> None:
|
||||
super()._setup()
|
||||
"""Set up the SB3 Agent."""
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
@@ -70,14 +79,43 @@ class SB3Agent(AgentSessionABC):
|
||||
timestamp_str=self.timestamp_str,
|
||||
)
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_train_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
# check if there is a zip file that needs to be loaded
|
||||
load_file = next(self.session_path.rglob("*.zip"), None)
|
||||
|
||||
if not load_file:
|
||||
# create a new env and agent
|
||||
|
||||
self._agent = self._agent_class(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=self.sb3_output_verbose_level,
|
||||
n_steps=self._training_config.num_train_steps,
|
||||
tensorboard_log=str(self._tensorboard_log_path),
|
||||
seed=self._training_config.seed,
|
||||
)
|
||||
else:
|
||||
# set env values from session metadata
|
||||
with open(self.session_path / "session_metadata.json", "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
# load environment values
|
||||
if self.is_eval:
|
||||
# evaluation always starts at 0
|
||||
self._env.episode_count = 0
|
||||
self._env.total_step_count = 0
|
||||
else:
|
||||
# carry on from previous learning sessions
|
||||
self._env.episode_count = md_dict["learning"]["total_episodes"]
|
||||
self._env.total_step_count = md_dict["learning"]["total_time_steps"]
|
||||
|
||||
# load the file
|
||||
self._agent = self._agent_class.load(load_file, env=self._env)
|
||||
|
||||
# set agent values
|
||||
self._agent.verbose = self.sb3_output_verbose_level
|
||||
self._agent.tensorboard_log = self.session_path / "learning/tensorboard_logs"
|
||||
|
||||
super()._setup()
|
||||
|
||||
def _save_checkpoint(self) -> None:
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
@@ -149,11 +187,6 @@ class SB3Agent(AgentSessionABC):
|
||||
self._env.close()
|
||||
super().evaluate()
|
||||
|
||||
@classmethod
|
||||
def load(cls, path: Union[str, Path]) -> SB3Agent:
|
||||
"""Load an agent from file."""
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self) -> None:
|
||||
"""Save the agent."""
|
||||
self._agent.save(self._saved_agent_path)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite.agents.agent import HardCodedAgentSessionABC
|
||||
from primaite.agents.hardcoded_abc import HardCodedAgentSessionABC
|
||||
from primaite.agents.utils import get_new_action, transform_action_acl_enum, transform_action_node_enum
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Provides a CLI using Typer as an entry point."""
|
||||
import logging
|
||||
import os
|
||||
@@ -151,7 +151,7 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
@@ -162,11 +162,19 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None) -> None:
|
||||
ldc: The lay down config file path. Optional. If no value is passed then
|
||||
example default lay down config is used from:
|
||||
~/primaite/config/example_config/lay_down/lay_down_config_3_doc_very_basic.yaml.
|
||||
|
||||
load: The directory of a previous session. Optional. If no value is passed, then the session
|
||||
will use the default training config and laydown config. Inversely, if a training config and laydown config
|
||||
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
|
||||
and laydown config.
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if load is not None:
|
||||
run(session_path=load)
|
||||
|
||||
if not tc:
|
||||
tc = main_training_config_path()
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Objects which are shared between many PrimAITE modules."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Enumerations for APE."""
|
||||
|
||||
from enum import Enum, IntEnum
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The protocol class."""
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The Service class."""
|
||||
|
||||
from primaite.common.enums import SoftwareState
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Configuration parameters for running experiments."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, TYPE_CHECKING, Union
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Utility to generate plots of sessions metrics after PrimAITE."""
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Gym/Gymnasium environment for RL agents consisting of a simulated computer network."""
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Main environment module containing the PRIMmary AI Training Evironment (Primaite) class."""
|
||||
import copy
|
||||
import logging
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements reward function."""
|
||||
from typing import Dict, TYPE_CHECKING, Union
|
||||
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Network connections between nodes in the simulation."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The link class."""
|
||||
from typing import List
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.primaite_session import PrimaiteSession
|
||||
@@ -11,16 +11,21 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: The training config filepath.
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path)
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
|
||||
|
||||
session.setup()
|
||||
session.learn()
|
||||
@@ -31,9 +36,14 @@ if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--tc")
|
||||
parser.add_argument("--ldc")
|
||||
parser.add_argument("--load")
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
if args.load:
|
||||
run(session_path=args.load)
|
||||
else:
|
||||
if not args.tc:
|
||||
_LOGGER.error("Please provide a training config file using the --tc " "argument")
|
||||
if not args.ldc:
|
||||
_LOGGER.error("Please provide a lay down config file using the --ldc " "argument")
|
||||
run(training_config_path=args.tc, lay_down_config_path=args.ldc)
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Nodes represent network hosts in the simulation."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""An Active Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Final
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The base Node class."""
|
||||
from typing import Final
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Defines node behaviour for Green PoL."""
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Union
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The Passive Node class (i.e. an actuator)."""
|
||||
from primaite.common.enums import HardwareState, NodeType, Priority
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""A Service Node (i.e. not an actuator)."""
|
||||
import logging
|
||||
from typing import Dict, Final
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Contains default jupyter notebooks which demonstrate PrimAITE functionality."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import subprocess
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Pattern of Life- Represents the actions of users on the network."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements Pattern of Life on the network (nodes and links)."""
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""
|
||||
Information Exchange Requirements for APE.
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Implements POL on the network (nodes and links) resulting from the red agent attack."""
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Main entry point to PrimAITE. Configure training/evaluation experiments and input/output."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Union
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.agents.agent import AgentSessionABC
|
||||
from primaite.agents.agent_abc import AgentSessionABC
|
||||
from primaite.agents.hardcoded_acl import HardCodedACLAgent
|
||||
from primaite.agents.hardcoded_node import HardCodedNodeAgent
|
||||
from primaite.agents.rllib import RLlibAgent
|
||||
@@ -14,6 +15,7 @@ from primaite.agents.simple import DoNothingACLAgent, DoNothingNodeAgent, DummyA
|
||||
from primaite.common.enums import ActionType, AgentFramework, AgentIdentifier, SessionType
|
||||
from primaite.config import lay_down_config, training_config
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.utils.session_metadata_parser import parse_session_metadata
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -27,15 +29,39 @@ class PrimaiteSession:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
|
||||
:param training_config_path: The training config path.
|
||||
:param lay_down_config_path: The lay down config path.
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
|
||||
# check if session path is provided
|
||||
if session_path is not None:
|
||||
# set load_session to true
|
||||
self.is_load_session = True
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
# if a session path is provided, load it
|
||||
if not session_path.exists():
|
||||
raise Exception(f"Session could not be loaded. Path does not exist: {session_path}")
|
||||
|
||||
md_dict, training_config_path, lay_down_config_path = parse_session_metadata(session_path)
|
||||
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
@@ -60,11 +86,15 @@ class PrimaiteSession:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.HARDCODED}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
# Deterministic Hardcoded Agent with Node Action Space
|
||||
self._agent_session = HardCodedNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = HardCodedNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = HardCodedACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = HardCodedACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -77,11 +107,15 @@ class PrimaiteSession:
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DO_NOTHING:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DO_NOTHING}")
|
||||
if self._training_config.action_type == ActionType.NODE:
|
||||
self._agent_session = DoNothingNodeAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DoNothingNodeAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ACL:
|
||||
# Deterministic Hardcoded Agent with ACL Action Space
|
||||
self._agent_session = DoNothingACLAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DoNothingACLAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
elif self._training_config.action_type == ActionType.ANY:
|
||||
# Deterministic Hardcoded Agent with ANY Action Space
|
||||
@@ -93,10 +127,14 @@ class PrimaiteSession:
|
||||
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.RANDOM:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.RANDOM}")
|
||||
self._agent_session = RandomAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = RandomAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
elif self._training_config.agent_identifier == AgentIdentifier.DUMMY:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Identifier =" f" {AgentIdentifier.DUMMY}")
|
||||
self._agent_session = DummyAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = DummyAgent(
|
||||
self._training_config_path, self._lay_down_config_path, self.session_path
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework AgentIdentifier combo
|
||||
@@ -105,12 +143,12 @@ class PrimaiteSession:
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path)
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Utilities to prepare the user's data folders."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import filecmp
|
||||
import os
|
||||
import shutil
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from primaite import _USER_DIRS, getLogger, LOG_DIR, NOTEBOOKS_DIR
|
||||
|
||||
@@ -1,2 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Record data of the system's state and agent's observations and actions."""
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""The Transaction class."""
|
||||
from datetime import datetime
|
||||
from typing import List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
@@ -1 +1,2 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
"""Utilities for PrimAITE."""
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
59
src/primaite/utils/session_metadata_parser.py
Normal file
59
src/primaite/utils/session_metadata_parser.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def parse_session_metadata(session_path: Union[Path, str], dict_only=False):
|
||||
"""
|
||||
Loads a session metadata from the given directory path.
|
||||
|
||||
:param session_path: Directory where the session metadata file is in
|
||||
:param dict_only: If dict_only is true, the function will only return the dict contents of session metadata
|
||||
|
||||
:return: Dictionary which has all the session metadata contents
|
||||
:rtype: Dict
|
||||
|
||||
:return: Path where the YAML copy of the training config is dumped into
|
||||
:rtype: str
|
||||
:return: Path where the YAML copy of the laydown config is dumped into
|
||||
:rtype: str
|
||||
"""
|
||||
if not isinstance(session_path, Path):
|
||||
session_path = Path(session_path)
|
||||
|
||||
if not session_path.exists():
|
||||
# Session path does not exist
|
||||
msg = f"Failed to load PrimAITE Session, path does not exist: {session_path}"
|
||||
_LOGGER.error(msg)
|
||||
raise FileNotFoundError(msg)
|
||||
|
||||
# Unpack the session_metadata.json file
|
||||
md_file = session_path / "session_metadata.json"
|
||||
with open(md_file, "r") as file:
|
||||
md_dict = json.load(file)
|
||||
|
||||
# if dict only, return dict without doing anything else
|
||||
if dict_only:
|
||||
return md_dict
|
||||
|
||||
# Create a temp directory and dump the training and lay down
|
||||
# configs into it
|
||||
temp_dir = session_path / ".temp"
|
||||
temp_dir.mkdir(exist_ok=True)
|
||||
|
||||
temp_tc = temp_dir / "tc.yaml"
|
||||
with open(temp_tc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["training_config"], file)
|
||||
|
||||
temp_ldc = temp_dir / "ldc.yaml"
|
||||
with open(temp_ldc, "w") as file:
|
||||
yaml.dump(md_dict["env"]["lay_down_config"], file)
|
||||
|
||||
return [md_dict, temp_tc, temp_ldc]
|
||||
@@ -1,5 +1,6 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
from pathlib import Path
|
||||
from typing import Dict, Union
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
# Using polars as it's faster than Pandas; it will speed things up when
|
||||
# files get big!
|
||||
@@ -13,8 +14,33 @@ def av_rewards_dict(av_rewards_csv_file: Union[str, Path]) -> Dict[int, float]:
|
||||
The dictionary keys are the episode number, and the values are the mean reward that episode.
|
||||
|
||||
:param av_rewards_csv_file: The average rewards per episode csv file path.
|
||||
:return: The average rewards per episode cdv as a dict.
|
||||
:return: The average rewards per episode csv as a dict.
|
||||
"""
|
||||
df = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
df_dict = pl.read_csv(av_rewards_csv_file).to_dict()
|
||||
|
||||
return {v: df["Average Reward"][i] for i, v in enumerate(df["Episode"])}
|
||||
return {v: df_dict["Average Reward"][i] for i, v in enumerate(df_dict["Episode"])}
|
||||
|
||||
|
||||
def all_transactions_dict(all_transactions_csv_file: Union[str, Path]) -> Dict[Tuple[int, int], Dict[str, Any]]:
|
||||
"""
|
||||
Read an all transactions csv file and return as a dict.
|
||||
|
||||
The dict keys are a tuple with the structure (episode, step). The dict
|
||||
values are the remaining columns as a dict.
|
||||
|
||||
:param all_transactions_csv_file: The all transactions csv file path.
|
||||
:return: The all transactions csv file as a dict.
|
||||
"""
|
||||
df_dict = pl.read_csv(all_transactions_csv_file).to_dict()
|
||||
new_dict = {}
|
||||
|
||||
episodes = df_dict["Episode"]
|
||||
steps = df_dict["Step"]
|
||||
keys = list(df_dict.keys())
|
||||
|
||||
for i in range(len(episodes)):
|
||||
key = (episodes[i], steps[i])
|
||||
value_dict = {key: df_dict[key][i] for key in keys if key not in ["Episode", "Step"]}
|
||||
new_dict[key] = value_dict
|
||||
|
||||
return new_dict
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
# Crown Owned Copyright (C) Dstl 2023. DEFCON 703. Shared in confidence.
|
||||
import csv
|
||||
from logging import Logger
|
||||
from typing import Final, List, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
Reference in New Issue
Block a user