diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 94c2730f..a1b0d9ac 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -1,17 +1,22 @@ """Module for handling configurable observation spaces in PrimAITE.""" import logging from abc import ABC, abstractmethod -from enum import Enum -from typing import List, Tuple +from typing import TYPE_CHECKING, List, Tuple import numpy as np from gym import spaces from primaite.common.enums import FileSystemState, HardwareState, SoftwareState -from primaite.environment.primaite_env import Primaite from primaite.nodes.active_node import ActiveNode from primaite.nodes.service_node import ServiceNode +# This dependency is only needed for type hints, +# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking +# Therefore, this avoids circular dependency problem. +if TYPE_CHECKING: + from primaite.environment.primaite_env import Primaite + + _LOGGER = logging.getLogger(__name__) @@ -19,9 +24,9 @@ class AbstractObservationComponent(ABC): """Represents a part of the PrimAITE observation space.""" @abstractmethod - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): _LOGGER.info(f"Initialising {self} observation component") - self.env: Primaite = env + self.env: "Primaite" = env self.space: spaces.Space self.current_observation: np.ndarray # type might be too restrictive? return NotImplemented @@ -51,7 +56,7 @@ class NodeLinkTable(AbstractObservationComponent): For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be ``(12, 7)`` - #todo: clean up description + #TODO: clean up description """ @@ -59,7 +64,7 @@ class NodeLinkTable(AbstractObservationComponent): _MAX_VAL = 1_000_000 _DATA_TYPE = np.int64 - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): super().__init__(env) # 1. Define the shape of your observation space component @@ -76,16 +81,16 @@ class NodeLinkTable(AbstractObservationComponent): ) # 3. Initialise Observation with zeroes - self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE) + self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) - def update_obs(self): + def update(self): """Update the observation. Update the environment's observation state based on the current status of nodes and links. The structure of the observation space is described in :func:`~_init_box_observations` This function can only be called if the observation space setting is set to BOX. - todo: complete description.. + TODO: complete description.. """ item_index = 0 nodes = self.env.nodes @@ -136,7 +141,7 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """todo: complete description. + """TODO: complete description. This will create the observation space with node observations followed by link observations. Each node has 3 elements in the observation space plus 1 per service, more specifically: @@ -148,7 +153,7 @@ class NodeStatuses(AbstractObservationComponent): _DATA_TYPE = np.int64 - def __init__(self, env: Primaite): + def __init__(self, env: "Primaite"): super().__init__(env) # 1. Define the shape of your observation space component @@ -166,8 +171,8 @@ class NodeStatuses(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - def update_obs(self): - """todo: complete description. + def update(self): + """TODO: complete description. Update the environment's observation state based on the current status of nodes and links. @@ -196,7 +201,7 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """todo: complete description. + """TODO: complete description. Each link has one element in the observation space, corresponding to the traffic load, it can take the following values: @@ -211,7 +216,7 @@ class LinkTrafficLevels(AbstractObservationComponent): def __init__( self, - env: Primaite, + env: "Primaite", combine_service_traffic: bool = False, quantisation_levels: int = 5, ): @@ -234,8 +239,8 @@ class LinkTrafficLevels(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - def update_obs(self): - """todo: complete description.""" + def update(self): + """TODO: complete description.""" obs = [] for _, link in self.env.links.items(): bandwidth = link.bandwidth @@ -262,15 +267,14 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: """Component-based observation space handler.""" - class registry(Enum): - """todo: complete description.""" - - NODE_LINK_TABLE: NodeLinkTable - NODE_STATUSES: NodeStatuses - LINK_TRAFFIC_LEVELS: LinkTrafficLevels + registry = { + "NODE_LINK_TABLE": NodeLinkTable, + "NODE_STATUSES": NodeStatuses, + "LINK_TRAFFIC_LEVELS": LinkTrafficLevels, + } def __init__(self): - """todo: complete description.""" + """TODO: complete description.""" """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space @@ -279,33 +283,33 @@ class ObservationsHandler: # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): - """todo: complete description.""" + """TODO: complete description.""" current_obs = [] for obs in self.registered_obs_components: - obs.update_obs() + obs.update() current_obs.append(obs.current_observation) self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): - """todo: complete description.""" + """TODO: complete description.""" self.registered_obs_components.append(obs_component) self.update_space() def deregister(self, obs_component: AbstractObservationComponent): - """todo: complete description.""" + """TODO: complete description.""" self.registered_obs_components.remove(obs_component) self.update_space() def update_space(self): - """todo: complete description.""" + """TODO: complete description.""" component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) self.space = spaces.Tuple(component_spaces) @classmethod - def from_config(cls, obs_space_config): - """todo: complete description. + def from_config(cls, env: "Primaite", obs_space_config: dict): + """TODO: complete description. This method parses config items related to the observation space, then creates the necessary components and adds them to the observation handler. @@ -316,11 +320,13 @@ class ObservationsHandler: for component_cfg in obs_space_config["components"]: # Figure out which class can instantiate the desired component comp_type = component_cfg["name"] - comp_class = cls.registry[comp_type].value + comp_class = cls.registry[comp_type] # Create the component with options from the YAML - component = comp_class(**component_cfg["options"]) + options = component_cfg.get("options") or {} + component = comp_class(env, **options) handler.register(component) + handler.update_obs() return handler diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index afa04060..0107920f 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -149,7 +149,8 @@ class Primaite(Env): # The action type self.action_type = 0 - # todo: proper description here + # TODO: proper description here + self.obs_config: dict self.obs_handler: ObservationsHandler # Open the config file and build the environment laydown @@ -161,10 +162,6 @@ class Primaite(Env): _LOGGER.error("Could not load the environment configuration") _LOGGER.error("Exception occured", exc_info=True) - # If it doesn't exist after parsing config, create default obs space. - if self.get("obs_handler") is None: - self.configure_obs_space() - # Store the node objects as node attributes # (This is so we can access them as objects) for node in self.network: @@ -195,6 +192,10 @@ class Primaite(Env): _LOGGER.error("Exception occured", exc_info=True) print("Could not save network diagram") + # # If it doesn't exist after parsing config, create default obs space. + # if getattr(self, "obs_handler", None) is None: + # self.configure_obs_space() + # Initiate observation space self.observation_space, self.env_obs = self.init_observations() @@ -646,13 +647,22 @@ class Primaite(Env): pass def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """todo: write docstring.""" + """TODO: write docstring.""" + if getattr(self, "obs_config", None) is None: + self.obs_config = { + "components": [ + {"name": "NODE_LINK_TABLE"}, + ] + } + + self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) + return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): """Updates the observation space based on the node and link status. - todo: better docstring + TODO: better docstring """ self.obs_handler.update_obs() self.env_obs = self.obs_handler.current_observation @@ -692,7 +702,7 @@ class Primaite(Env): self.get_action_info(item) elif item["itemType"] == "OBSERVATION_SPACE": # Get the observation information - self.configure_obs_space(item) + self.save_obs_config(item) elif item["itemType"] == "STEPS": # Get the steps information self.get_steps_info(item) @@ -1025,16 +1035,9 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def configure_obs_space(self, observation_config: Optional[Dict] = None): - """todo: better docstring.""" - if observation_config is None: - observation_config = { - "components": [ - {"name": "NODE_LINK_TABLE"}, - ] - } - - self.obs_handler = ObservationsHandler[observation_config] + def save_obs_config(self, obs_config: Optional[Dict] = None): + """TODO: better docstring.""" + self.obs_config = obs_config def get_steps_info(self, steps_info): """