Fix trying to init obs before building network

This commit is contained in:
Marek Wolan
2023-06-01 17:42:35 +01:00
parent 2b25573378
commit 7041b79d2a
2 changed files with 61 additions and 52 deletions

View File

@@ -1,17 +1,22 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Tuple
from typing import TYPE_CHECKING, List, Tuple
import numpy as np
from gym import spaces
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.environment.primaite_env import Primaite
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# This dependency is only needed for type hints,
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
@@ -19,9 +24,9 @@ class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: Primaite):
def __init__(self, env: "Primaite"):
_LOGGER.info(f"Initialising {self} observation component")
self.env: Primaite = env
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
return NotImplemented
@@ -51,7 +56,7 @@ class NodeLinkTable(AbstractObservationComponent):
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
#todo: clean up description
#TODO: clean up description
"""
@@ -59,7 +64,7 @@ class NodeLinkTable(AbstractObservationComponent):
_MAX_VAL = 1_000_000
_DATA_TYPE = np.int64
def __init__(self, env: Primaite):
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
@@ -76,16 +81,16 @@ class NodeLinkTable(AbstractObservationComponent):
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE)
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
def update_obs(self):
def update(self):
"""Update the observation.
Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_box_observations`
This function can only be called if the observation space setting is set to BOX.
todo: complete description..
TODO: complete description..
"""
item_index = 0
nodes = self.env.nodes
@@ -136,7 +141,7 @@ class NodeLinkTable(AbstractObservationComponent):
class NodeStatuses(AbstractObservationComponent):
"""todo: complete description.
"""TODO: complete description.
This will create the observation space with node observations followed by link observations.
Each node has 3 elements in the observation space plus 1 per service, more specifically:
@@ -148,7 +153,7 @@ class NodeStatuses(AbstractObservationComponent):
_DATA_TYPE = np.int64
def __init__(self, env: Primaite):
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
@@ -166,8 +171,8 @@ class NodeStatuses(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update_obs(self):
"""todo: complete description.
def update(self):
"""TODO: complete description.
Update the environment's observation state based on the current status of nodes and links.
@@ -196,7 +201,7 @@ class NodeStatuses(AbstractObservationComponent):
class LinkTrafficLevels(AbstractObservationComponent):
"""todo: complete description.
"""TODO: complete description.
Each link has one element in the observation space, corresponding to the traffic load,
it can take the following values:
@@ -211,7 +216,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
def __init__(
self,
env: Primaite,
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
@@ -234,8 +239,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update_obs(self):
"""todo: complete description."""
def update(self):
"""TODO: complete description."""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
@@ -262,15 +267,14 @@ class LinkTrafficLevels(AbstractObservationComponent):
class ObservationsHandler:
"""Component-based observation space handler."""
class registry(Enum):
"""todo: complete description."""
NODE_LINK_TABLE: NodeLinkTable
NODE_STATUSES: NodeStatuses
LINK_TRAFFIC_LEVELS: LinkTrafficLevels
registry = {
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
}
def __init__(self):
"""todo: complete description."""
"""TODO: complete description."""
"""Initialise the handler without any components yet. They"""
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
@@ -279,33 +283,33 @@ class ObservationsHandler:
# self.registry.LINK_TRAFFIC_LEVELS
def update_obs(self):
"""todo: complete description."""
"""TODO: complete description."""
current_obs = []
for obs in self.registered_obs_components:
obs.update_obs()
obs.update()
current_obs.append(obs.current_observation)
self.current_observation = tuple(current_obs)
def register(self, obs_component: AbstractObservationComponent):
"""todo: complete description."""
"""TODO: complete description."""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
"""todo: complete description."""
"""TODO: complete description."""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
"""todo: complete description."""
"""TODO: complete description."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
self.space = spaces.Tuple(component_spaces)
@classmethod
def from_config(cls, obs_space_config):
"""todo: complete description.
def from_config(cls, env: "Primaite", obs_space_config: dict):
"""TODO: complete description.
This method parses config items related to the observation space, then
creates the necessary components and adds them to the observation handler.
@@ -316,11 +320,13 @@ class ObservationsHandler:
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
comp_class = cls.registry[comp_type].value
comp_class = cls.registry[comp_type]
# Create the component with options from the YAML
component = comp_class(**component_cfg["options"])
options = component_cfg.get("options") or {}
component = comp_class(env, **options)
handler.register(component)
handler.update_obs()
return handler

View File

@@ -149,7 +149,8 @@ class Primaite(Env):
# The action type
self.action_type = 0
# todo: proper description here
# TODO: proper description here
self.obs_config: dict
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
@@ -161,10 +162,6 @@ class Primaite(Env):
_LOGGER.error("Could not load the environment configuration")
_LOGGER.error("Exception occured", exc_info=True)
# If it doesn't exist after parsing config, create default obs space.
if self.get("obs_handler") is None:
self.configure_obs_space()
# Store the node objects as node attributes
# (This is so we can access them as objects)
for node in self.network:
@@ -195,6 +192,10 @@ class Primaite(Env):
_LOGGER.error("Exception occured", exc_info=True)
print("Could not save network diagram")
# # If it doesn't exist after parsing config, create default obs space.
# if getattr(self, "obs_handler", None) is None:
# self.configure_obs_space()
# Initiate observation space
self.observation_space, self.env_obs = self.init_observations()
@@ -646,13 +647,22 @@ class Primaite(Env):
pass
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""todo: write docstring."""
"""TODO: write docstring."""
if getattr(self, "obs_config", None) is None:
self.obs_config = {
"components": [
{"name": "NODE_LINK_TABLE"},
]
}
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
"""Updates the observation space based on the node and link status.
todo: better docstring
TODO: better docstring
"""
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
@@ -692,7 +702,7 @@ class Primaite(Env):
self.get_action_info(item)
elif item["itemType"] == "OBSERVATION_SPACE":
# Get the observation information
self.configure_obs_space(item)
self.save_obs_config(item)
elif item["itemType"] == "STEPS":
# Get the steps information
self.get_steps_info(item)
@@ -1025,16 +1035,9 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def configure_obs_space(self, observation_config: Optional[Dict] = None):
"""todo: better docstring."""
if observation_config is None:
observation_config = {
"components": [
{"name": "NODE_LINK_TABLE"},
]
}
self.obs_handler = ObservationsHandler[observation_config]
def save_obs_config(self, obs_config: Optional[Dict] = None):
"""TODO: better docstring."""
self.obs_config = obs_config
def get_steps_info(self, steps_info):
"""