#1711 - Fully Integrated the legacy training config and lay down config options into the CLI, run PrimaiteSession, and Agent classes. Made the ese test in test_full_legacy_config_session.py use this new integrated option to read the legacy file.
This commit is contained in:
@@ -52,6 +52,8 @@ class AgentSessionABC(ABC):
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an agent session from config files, or load a previous session.
|
||||
@@ -64,6 +66,10 @@ class AgentSessionABC(ABC):
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param session_path: directory path of the session to load
|
||||
"""
|
||||
# initialise variables
|
||||
@@ -72,6 +78,8 @@ class AgentSessionABC(ABC):
|
||||
self._can_learn: bool = False
|
||||
self._can_evaluate: bool = False
|
||||
self.is_eval = False
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
|
||||
@@ -91,12 +99,14 @@ class AgentSessionABC(ABC):
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Union[Path, str] = training_config_path
|
||||
self._training_config: TrainingConfig = training_config.load(self._training_config_path)
|
||||
self._training_config: TrainingConfig = training_config.load(
|
||||
self._training_config_path, legacy_file=legacy_training_config
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Union[Path, str] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path)
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config)
|
||||
self.sb3_output_verbose_level = self._training_config.sb3_output_verbose_level
|
||||
|
||||
# set random UUID for session
|
||||
|
||||
@@ -26,6 +26,8 @@ class SB3Agent(AgentSessionABC):
|
||||
training_config_path: Optional[Union[str, Path]] = None,
|
||||
lay_down_config_path: Optional[Union[str, Path]] = None,
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise the SB3 Agent training session.
|
||||
@@ -35,11 +37,17 @@ class SB3Agent(AgentSessionABC):
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_framework (should be "SB3")
|
||||
:raises ValueError: If the training config contains an unexpected value for agent_identifies (should be `PPO`
|
||||
or `A2C`)
|
||||
"""
|
||||
super().__init__(training_config_path, lay_down_config_path, session_path)
|
||||
super().__init__(
|
||||
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
||||
)
|
||||
if not self._training_config.agent_framework == AgentFramework.SB3:
|
||||
msg = f"Expected SB3 agent_framework, " f"got {self._training_config.agent_framework}"
|
||||
_LOGGER.error(msg)
|
||||
@@ -75,6 +83,8 @@ class SB3Agent(AgentSessionABC):
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str,
|
||||
legacy_training_config=self.legacy_training_config,
|
||||
legacy_lay_down_config=self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
# check if there is a zip file that needs to be loaded
|
||||
|
||||
@@ -18,9 +18,9 @@ app = typer.Typer()
|
||||
@app.command()
|
||||
def build_dirs() -> None:
|
||||
"""Build the PrimAITE app directories."""
|
||||
from primaite.setup import setup_app_dirs
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
setup_app_dirs.run()
|
||||
PRIMAITE_PATHS.mkdirs()
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -137,7 +137,13 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[str] = None) -> None:
|
||||
def session(
|
||||
tc: Optional[str] = None,
|
||||
ldc: Optional[str] = None,
|
||||
load: Optional[str] = None,
|
||||
legacy_tc: bool = False,
|
||||
legacy_ldc: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
@@ -153,6 +159,10 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
|
||||
will use the default training config and laydown config. Inversely, if a training config and laydown config
|
||||
is passed while a session directory is passed, PrimAITE will load the session and ignore the training config
|
||||
and laydown config.
|
||||
|
||||
legacy_tc: If the training config file is a legacy file from PrimAITE < 2.0.
|
||||
|
||||
legacy_ldf: If the lay down config file is a legacy file from PrimAITE < 2.0.
|
||||
"""
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
@@ -170,7 +180,12 @@ def session(tc: Optional[str] = None, ldc: Optional[str] = None, load: Optional[
|
||||
if not ldc:
|
||||
ldc = dos_very_basic_config_path()
|
||||
|
||||
run(training_config_path=tc, lay_down_config_path=ldc)
|
||||
run(
|
||||
training_config_path=tc,
|
||||
lay_down_config_path=ldc,
|
||||
legacy_training_config=legacy_tc,
|
||||
legacy_lay_down_config=legacy_ldc,
|
||||
)
|
||||
|
||||
|
||||
@app.command()
|
||||
|
||||
@@ -291,12 +291,14 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> TrainingConf
|
||||
if legacy_file:
|
||||
try:
|
||||
config = convert_legacy_training_config_dict(config)
|
||||
except KeyError:
|
||||
|
||||
except KeyError as e:
|
||||
msg = (
|
||||
f"Failed to convert training config file {file_path} "
|
||||
f"from legacy format. Attempting to use file as is."
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
raise e
|
||||
try:
|
||||
return TrainingConfig.from_dict(config)
|
||||
except TypeError as e:
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Any, Dict, Final, List, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import yaml
|
||||
from gym import Env, spaces
|
||||
from matplotlib import pyplot as plt
|
||||
|
||||
@@ -34,6 +33,7 @@ from primaite.common.enums import (
|
||||
)
|
||||
from primaite.common.service import Service
|
||||
from primaite.config import training_config
|
||||
from primaite.config.lay_down_config import load
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.environment.observations import ObservationsHandler
|
||||
from primaite.environment.reward import calculate_reward_function
|
||||
@@ -68,6 +68,8 @@ class Primaite(Env):
|
||||
lay_down_config_path: Union[str, Path],
|
||||
session_path: Path,
|
||||
timestamp_str: str,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
The Primaite constructor.
|
||||
@@ -76,13 +78,19 @@ class Primaite(Env):
|
||||
:param lay_down_config_path: The lay down config filepath.
|
||||
:param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format: <yyyy-mm-dd>_<hh-mm- ss>.
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
self.session_path: Final[Path] = session_path
|
||||
self.timestamp_str: Final[str] = timestamp_str
|
||||
self._training_config_path: Union[str, Path] = training_config_path
|
||||
self._lay_down_config_path: Union[str, Path] = lay_down_config_path
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
self.training_config: TrainingConfig = training_config.load(training_config_path)
|
||||
self.training_config: TrainingConfig = training_config.load(training_config_path, self.legacy_training_config)
|
||||
_LOGGER.info(f"Using: {str(self.training_config)}")
|
||||
|
||||
# Number of steps in an episode
|
||||
@@ -191,11 +199,8 @@ class Primaite(Env):
|
||||
self._obs_space_description: List[str] = None
|
||||
"The env observation space description for transactions writing"
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
with open(self._lay_down_config_path, "r") as file:
|
||||
# Open the config file and build the environment laydown
|
||||
self.lay_down_config = yaml.safe_load(file)
|
||||
self.load_lay_down_config()
|
||||
self.lay_down_config = load(self._lay_down_config_path, self.legacy_lay_down_config)
|
||||
self.load_lay_down_config()
|
||||
|
||||
# Store the node objects as node attributes
|
||||
# (This is so we can access them as objects)
|
||||
|
||||
@@ -14,18 +14,26 @@ def run(
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
session = PrimaiteSession(training_config_path, lay_down_config_path, session_path)
|
||||
session = PrimaiteSession(
|
||||
training_config_path, lay_down_config_path, session_path, legacy_training_config, legacy_lay_down_config
|
||||
)
|
||||
|
||||
session.setup()
|
||||
session.learn()
|
||||
|
||||
@@ -34,6 +34,8 @@ class PrimaiteSession:
|
||||
training_config_path: Optional[Union[str, Path]] = "",
|
||||
lay_down_config_path: Optional[Union[str, Path]] = "",
|
||||
session_path: Optional[Union[str, Path]] = None,
|
||||
legacy_training_config: bool = False,
|
||||
legacy_lay_down_config: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
The PrimaiteSession constructor.
|
||||
@@ -44,12 +46,18 @@ class PrimaiteSession:
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
self._agent_session: AgentSessionABC = None # noqa
|
||||
self.session_path: Path = session_path # noqa
|
||||
self.timestamp_str: str = None # noqa
|
||||
self.learning_path: Path = None # noqa
|
||||
self.evaluation_path: Path = None # noqa
|
||||
self.legacy_training_config = legacy_training_config
|
||||
self.legacy_lay_down_config = legacy_lay_down_config
|
||||
|
||||
# check if session path is provided
|
||||
if session_path is not None:
|
||||
@@ -67,12 +75,14 @@ class PrimaiteSession:
|
||||
if not isinstance(training_config_path, Path):
|
||||
training_config_path = Path(training_config_path)
|
||||
self._training_config_path: Final[Union[Path, str]] = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(self._training_config_path)
|
||||
self._training_config: Final[TrainingConfig] = training_config.load(
|
||||
self._training_config_path, legacy_training_config
|
||||
)
|
||||
|
||||
if not isinstance(lay_down_config_path, Path):
|
||||
lay_down_config_path = Path(lay_down_config_path)
|
||||
self._lay_down_config_path: Final[Union[Path, str]] = lay_down_config_path
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path) # noqa
|
||||
self._lay_down_config: Dict = lay_down_config.load(self._lay_down_config_path, legacy_lay_down_config) # noqa
|
||||
|
||||
def setup(self) -> None:
|
||||
"""Performs the session setup."""
|
||||
@@ -139,12 +149,24 @@ class PrimaiteSession:
|
||||
elif self._training_config.agent_framework == AgentFramework.SB3:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.SB3}")
|
||||
# Stable Baselines3 Agent
|
||||
self._agent_session = SB3Agent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
self._agent_session = SB3Agent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path,
|
||||
self.session_path,
|
||||
self.legacy_training_config,
|
||||
self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
elif self._training_config.agent_framework == AgentFramework.RLLIB:
|
||||
_LOGGER.debug(f"PrimaiteSession Setup: Agent Framework = {AgentFramework.RLLIB}")
|
||||
# Ray RLlib Agent
|
||||
self._agent_session = RLlibAgent(self._training_config_path, self._lay_down_config_path, self.session_path)
|
||||
self._agent_session = RLlibAgent(
|
||||
self._training_config_path,
|
||||
self._lay_down_config_path,
|
||||
self.session_path,
|
||||
self.legacy_training_config,
|
||||
self.legacy_lay_down_config,
|
||||
)
|
||||
|
||||
else:
|
||||
# Invalid AgentFramework
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.config import training_config
|
||||
from primaite.config.lay_down_config import convert_legacy_lay_down_config
|
||||
from primaite.main import run
|
||||
from tests import TEST_CONFIG_ROOT
|
||||
|
||||
@@ -22,28 +17,13 @@ from tests import TEST_CONFIG_ROOT
|
||||
)
|
||||
def test_legacy_training_config_run_session(legacy_file):
|
||||
"""Tests using legacy training and lay down config files in PrimAITE session end-to-end."""
|
||||
# Load the legacy lay down config yaml file
|
||||
with open(TEST_CONFIG_ROOT / "legacy_conversion" / legacy_file, "r") as file:
|
||||
legacy_lay_down_config = yaml.safe_load(file)
|
||||
|
||||
# Convert the legacy lay down config to the new format
|
||||
converted_lay_down_config = convert_legacy_lay_down_config(legacy_lay_down_config)
|
||||
|
||||
# Write the converted lay down config file to yaml file
|
||||
temp_dir = Path(tempfile.gettempdir())
|
||||
converted_legacy_lay_down_path = temp_dir / legacy_file.replace("legacy_", "")
|
||||
with open(converted_legacy_lay_down_path, "w") as file:
|
||||
yaml.dump(converted_lay_down_config, file)
|
||||
|
||||
# Load the legacy training config yaml file and covvert it to the new format
|
||||
converted_legacy_training_config = training_config.load(
|
||||
TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml", legacy_file=True
|
||||
)
|
||||
|
||||
# Write the converted training config file to yaml file
|
||||
converted_legacy_training_path = temp_dir / "training_config.yaml"
|
||||
with open(converted_legacy_training_path, "w") as file:
|
||||
yaml.dump(converted_legacy_training_config.to_dict(json_serializable=True), file)
|
||||
legacy_training_config_path = TEST_CONFIG_ROOT / "legacy_conversion" / "legacy_training_config.yaml"
|
||||
legacy_lay_down_config_path = TEST_CONFIG_ROOT / "legacy_conversion" / legacy_file
|
||||
|
||||
# Run a PrimAITE session using the paths of both the converted training and lay down config files
|
||||
run(converted_legacy_training_path, converted_legacy_lay_down_path)
|
||||
run(
|
||||
legacy_training_config_path,
|
||||
legacy_lay_down_config_path,
|
||||
legacy_training_config=True,
|
||||
legacy_lay_down_config=True,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user