Apply suggestions from code review.

This commit is contained in:
Marek Wolan
2023-06-07 15:25:11 +01:00
parent 89cea9289b
commit 9417cd85ab
14 changed files with 338 additions and 308 deletions

View File

@@ -9,6 +9,7 @@ class ConfigValuesMain(object):
"""Init."""
# Generic
self.agent_identifier = "" # the agent in use
self.observation_config = None # observation space config
self.num_episodes = 0 # number of episodes to train over
self.num_steps = 0 # number of steps in an episode
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only

View File

@@ -1,7 +1,7 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, List, Tuple, Union
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
import numpy as np
from gym import spaces
@@ -56,9 +56,9 @@ class NodeLinkTable(AbstractObservationComponent):
``(12, 7)``
"""
_FIXED_PARAMETERS = 4
_MAX_VAL = 1_000_000
_DATA_TYPE = np.int64
_FIXED_PARAMETERS: int = 4
_MAX_VAL: int = 1_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
@@ -159,7 +159,7 @@ class NodeStatuses(AbstractObservationComponent):
:type env: Primaite
"""
_DATA_TYPE = np.int64
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
@@ -231,7 +231,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
:type quantisation_levels: int, optional
"""
_DATA_TYPE = np.int64
_DATA_TYPE: type = np.int64
def __init__(
self,
@@ -239,7 +239,14 @@ class LinkTrafficLevels(AbstractObservationComponent):
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
assert quantisation_levels >= 3
if quantisation_levels < 3:
_msg = (
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
f"Resetting to default value (5)"
)
_LOGGER.warning(_msg)
quantisation_levels = 5
super().__init__(env)
@@ -296,7 +303,7 @@ class ObservationsHandler:
Each component can also define further parameters to make them more flexible.
"""
registry = {
_REGISTRY: Final[Dict[str, type]] = {
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
@@ -384,7 +391,7 @@ class ObservationsHandler:
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
comp_class = cls.registry[comp_type]
comp_class = cls._REGISTRY[comp_type]
# Create the component with options from the YAML
options = component_cfg.get("options") or {}

View File

@@ -48,10 +48,10 @@ class Primaite(Env):
"""PRIMmary AI Training Evironment (Primaite) class."""
# Action Space contants
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
ACTION_SPACE_NODE_ACTION_VALUES = 4
ACTION_SPACE_ACL_ACTION_VALUES = 3
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(self, _config_values, _transaction_list):
"""
@@ -148,6 +148,8 @@ class Primaite(Env):
# stores the observation config from the yaml, default is NODE_LINK_TABLE
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
if self.config_values.observation_config is not None:
self.obs_config = self.config_values.observation_config
# Observation Handler manages the user-configurable observation space.
# It will be initialised later.
@@ -690,9 +692,6 @@ class Primaite(Env):
elif item["itemType"] == "ACTIONS":
# Get the action information
self.get_action_info(item)
elif item["itemType"] == "OBSERVATION_SPACE":
# Get the observation information
self.save_obs_config(item)
elif item["itemType"] == "STEPS":
# Get the steps information
self.get_steps_info(item)

View File

@@ -163,6 +163,10 @@ def load_config_values():
try:
# Generic
config_values.agent_identifier = config_data["agentIdentifier"]
if "observationSpace" in config_data:
config_values.observation_config = config_data["observationSpace"]
else:
config_values.observation_config = None
config_values.num_episodes = int(config_data["numEpisodes"])
config_values.time_delay = int(config_data["timeDelay"])
config_values.config_filename_use_case = (