#1711 - Added the ability to load legacy lay down config files. Added extensive unit testing and end-to-end testing. Also added the ability to set exactly how many num_train_steps, num_eval_steps, num_train_episodes, and num_eval_episode and used when converting a legacy training config.

This commit is contained in:
Chris McCarthy
2023-07-28 12:53:49 +01:00
parent 378001ff68
commit 7c843d3caa
10 changed files with 1383 additions and 9 deletions

View File

@@ -1,7 +1,7 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from logging import Logger
from pathlib import Path
from typing import Any, Dict, Final, Union
from typing import Any, Dict, Final, List, Union
import yaml
@@ -12,14 +12,43 @@ _LOGGER: Logger = getLogger(__name__)
_EXAMPLE_LAY_DOWN: Final[Path] = PRIMAITE_PATHS.user_config_path / "example_config" / "lay_down"
def convert_legacy_lay_down_config_dict(legacy_config_dict: Dict[str, Any]) -> Dict[str, Any]:
def convert_legacy_lay_down_config(legacy_config: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""
Convert a legacy lay down config dict to the new format.
Convert a legacy lay down config to the new format.
:param legacy_config_dict: A legacy lay down config dict.
:param legacy_config: A legacy lay down config.
"""
_LOGGER.warning("Legacy lay down config conversion not yet implemented")
return legacy_config_dict
field_conversion_map = {
"itemType": "item_type",
"portsList": "ports_list",
"serviceList": "service_list",
"baseType": "node_class",
"nodeType": "node_type",
"hardwareState": "hardware_state",
"softwareState": "software_state",
"startStep": "start_step",
"endStep": "end_step",
"fileSystemState": "file_system_state",
"ipAddress": "ip_address",
"missionCriticality": "mission_criticality",
}
new_config = []
for item in legacy_config:
if "itemType" in item:
if item["itemType"] in ["ACTIONS", "STEPS"]:
continue
new_dict = {}
for key in item.keys():
conversion_key = field_conversion_map.get(key)
if key == "id" and "itemType" in item:
if item["itemType"] == "NODE":
conversion_key = "node_id"
if conversion_key:
new_dict[conversion_key] = item[key]
else:
new_dict[key] = item[key]
new_config.append(new_dict)
return new_config
def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
@@ -39,7 +68,7 @@ def load(file_path: Union[str, Path], legacy_file: bool = False) -> Dict:
_LOGGER.debug(f"Loading lay down config file: {file_path}")
if legacy_file:
try:
config = convert_legacy_lay_down_config_dict(config)
config = convert_legacy_lay_down_config(config)
except KeyError:
msg = (
f"Failed to convert lay down config file {file_path} "

View File

@@ -314,6 +314,9 @@ def convert_legacy_training_config_dict(
agent_identifier: AgentIdentifier = AgentIdentifier.PPO,
action_type: ActionType = ActionType.ANY,
num_train_steps: int = 256,
num_eval_steps: int = 256,
num_train_episodes: int = 10,
num_eval_episodes: int = 1,
) -> Dict[str, Any]:
"""
Convert a legacy training config dict to the new format.
@@ -325,8 +328,14 @@ def convert_legacy_training_config_dict(
training configs don't have agent_identifier values.
:param action_type: The action space type to set as legacy training configs
don't have action_type values.
:param num_train_steps: The number of steps to set as legacy training configs
:param num_train_steps: The number of train steps to set as legacy training configs
don't have num_train_steps values.
:param num_eval_steps: The number of eval steps to set as legacy training configs
don't have num_eval_steps values.
:param num_train_episodes: The number of train episodes to set as legacy training configs
don't have num_train_episodes values.
:param num_eval_episodes: The number of eval episodes to set as legacy training configs
don't have num_eval_episodes values.
:return: The converted training config dict.
"""
config_dict = {
@@ -334,6 +343,9 @@ def convert_legacy_training_config_dict(
"agent_identifier": agent_identifier.name,
"action_type": action_type.name,
"num_train_steps": num_train_steps,
"num_eval_steps": num_eval_steps,
"num_train_episodes": num_train_episodes,
"num_eval_episodes": num_eval_episodes,
"sb3_output_verbose_level": SB3OutputVerboseLevel.INFO.name,
}
session_type_map = {"TRAINING": "TRAIN", "EVALUATION": "EVAL"}

View File

@@ -1027,7 +1027,7 @@ class Primaite(Env):
acl_rule_destination = item["destination"]
acl_rule_protocol = item["protocol"]
acl_rule_port = item["port"]
acl_rule_position = item["position"]
acl_rule_position = item.get("position")
self.acl.add_rule(
acl_rule_permission,