Apply suggestions from code review.
This commit is contained in:
@@ -9,6 +9,7 @@ class ConfigValuesMain(object):
|
||||
"""Init."""
|
||||
# Generic
|
||||
self.agent_identifier = "" # the agent in use
|
||||
self.observation_config = None # observation space config
|
||||
self.num_episodes = 0 # number of episodes to train over
|
||||
self.num_steps = 0 # number of steps in an episode
|
||||
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
@@ -56,9 +56,9 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
``(12, 7)``
|
||||
"""
|
||||
|
||||
_FIXED_PARAMETERS = 4
|
||||
_MAX_VAL = 1_000_000
|
||||
_DATA_TYPE = np.int64
|
||||
_FIXED_PARAMETERS: int = 4
|
||||
_MAX_VAL: int = 1_000_000
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
super().__init__(env)
|
||||
@@ -159,7 +159,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
:type env: Primaite
|
||||
"""
|
||||
|
||||
_DATA_TYPE = np.int64
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(self, env: "Primaite"):
|
||||
super().__init__(env)
|
||||
@@ -231,7 +231,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
:type quantisation_levels: int, optional
|
||||
"""
|
||||
|
||||
_DATA_TYPE = np.int64
|
||||
_DATA_TYPE: type = np.int64
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -239,7 +239,14 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
):
|
||||
assert quantisation_levels >= 3
|
||||
if quantisation_levels < 3:
|
||||
_msg = (
|
||||
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
|
||||
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
|
||||
f"Resetting to default value (5)"
|
||||
)
|
||||
_LOGGER.warning(_msg)
|
||||
quantisation_levels = 5
|
||||
|
||||
super().__init__(env)
|
||||
|
||||
@@ -296,7 +303,7 @@ class ObservationsHandler:
|
||||
Each component can also define further parameters to make them more flexible.
|
||||
"""
|
||||
|
||||
registry = {
|
||||
_REGISTRY: Final[Dict[str, type]] = {
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
@@ -384,7 +391,7 @@ class ObservationsHandler:
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
comp_class = cls.registry[comp_type]
|
||||
comp_class = cls._REGISTRY[comp_type]
|
||||
|
||||
# Create the component with options from the YAML
|
||||
options = component_cfg.get("options") or {}
|
||||
|
||||
@@ -48,10 +48,10 @@ class Primaite(Env):
|
||||
"""PRIMmary AI Training Evironment (Primaite) class."""
|
||||
|
||||
# Action Space contants
|
||||
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 4
|
||||
ACTION_SPACE_ACL_ACTION_VALUES = 3
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
|
||||
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
|
||||
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
|
||||
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
|
||||
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||
|
||||
def __init__(self, _config_values, _transaction_list):
|
||||
"""
|
||||
@@ -148,6 +148,8 @@ class Primaite(Env):
|
||||
|
||||
# stores the observation config from the yaml, default is NODE_LINK_TABLE
|
||||
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||
if self.config_values.observation_config is not None:
|
||||
self.obs_config = self.config_values.observation_config
|
||||
|
||||
# Observation Handler manages the user-configurable observation space.
|
||||
# It will be initialised later.
|
||||
@@ -690,9 +692,6 @@ class Primaite(Env):
|
||||
elif item["itemType"] == "ACTIONS":
|
||||
# Get the action information
|
||||
self.get_action_info(item)
|
||||
elif item["itemType"] == "OBSERVATION_SPACE":
|
||||
# Get the observation information
|
||||
self.save_obs_config(item)
|
||||
elif item["itemType"] == "STEPS":
|
||||
# Get the steps information
|
||||
self.get_steps_info(item)
|
||||
|
||||
@@ -163,6 +163,10 @@ def load_config_values():
|
||||
try:
|
||||
# Generic
|
||||
config_values.agent_identifier = config_data["agentIdentifier"]
|
||||
if "observationSpace" in config_data:
|
||||
config_values.observation_config = config_data["observationSpace"]
|
||||
else:
|
||||
config_values.observation_config = None
|
||||
config_values.num_episodes = int(config_data["numEpisodes"])
|
||||
config_values.time_delay = int(config_data["timeDelay"])
|
||||
config_values.config_filename_use_case = (
|
||||
|
||||
Reference in New Issue
Block a user