Fix trying to init obs before building network
This commit is contained in:
@@ -1,17 +1,22 @@
|
||||
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List, Tuple
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from gym import spaces
|
||||
|
||||
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from primaite.nodes.active_node import ActiveNode
|
||||
from primaite.nodes.service_node import ServiceNode
|
||||
|
||||
# This dependency is only needed for type hints,
|
||||
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||
# Therefore, this avoids circular dependency problem.
|
||||
if TYPE_CHECKING:
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -19,9 +24,9 @@ class AbstractObservationComponent(ABC):
|
||||
"""Represents a part of the PrimAITE observation space."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: Primaite):
|
||||
def __init__(self, env: "Primaite"):
|
||||
_LOGGER.info(f"Initialising {self} observation component")
|
||||
self.env: Primaite = env
|
||||
self.env: "Primaite" = env
|
||||
self.space: spaces.Space
|
||||
self.current_observation: np.ndarray # type might be too restrictive?
|
||||
return NotImplemented
|
||||
@@ -51,7 +56,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
|
||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
||||
``(12, 7)``
|
||||
#todo: clean up description
|
||||
#TODO: clean up description
|
||||
|
||||
"""
|
||||
|
||||
@@ -59,7 +64,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
_MAX_VAL = 1_000_000
|
||||
_DATA_TYPE = np.int64
|
||||
|
||||
def __init__(self, env: Primaite):
|
||||
def __init__(self, env: "Primaite"):
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
@@ -76,16 +81,16 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
)
|
||||
|
||||
# 3. Initialise Observation with zeroes
|
||||
self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE)
|
||||
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||
|
||||
def update_obs(self):
|
||||
def update(self):
|
||||
"""Update the observation.
|
||||
|
||||
Update the environment's observation state based on the current status of nodes and links.
|
||||
|
||||
The structure of the observation space is described in :func:`~_init_box_observations`
|
||||
This function can only be called if the observation space setting is set to BOX.
|
||||
todo: complete description..
|
||||
TODO: complete description..
|
||||
"""
|
||||
item_index = 0
|
||||
nodes = self.env.nodes
|
||||
@@ -136,7 +141,7 @@ class NodeLinkTable(AbstractObservationComponent):
|
||||
|
||||
|
||||
class NodeStatuses(AbstractObservationComponent):
|
||||
"""todo: complete description.
|
||||
"""TODO: complete description.
|
||||
|
||||
This will create the observation space with node observations followed by link observations.
|
||||
Each node has 3 elements in the observation space plus 1 per service, more specifically:
|
||||
@@ -148,7 +153,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
_DATA_TYPE = np.int64
|
||||
|
||||
def __init__(self, env: Primaite):
|
||||
def __init__(self, env: "Primaite"):
|
||||
super().__init__(env)
|
||||
|
||||
# 1. Define the shape of your observation space component
|
||||
@@ -166,8 +171,8 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
def update_obs(self):
|
||||
"""todo: complete description.
|
||||
def update(self):
|
||||
"""TODO: complete description.
|
||||
|
||||
Update the environment's observation state based on the current status of nodes and links.
|
||||
|
||||
@@ -196,7 +201,7 @@ class NodeStatuses(AbstractObservationComponent):
|
||||
|
||||
|
||||
class LinkTrafficLevels(AbstractObservationComponent):
|
||||
"""todo: complete description.
|
||||
"""TODO: complete description.
|
||||
|
||||
Each link has one element in the observation space, corresponding to the traffic load,
|
||||
it can take the following values:
|
||||
@@ -211,7 +216,7 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
env: Primaite,
|
||||
env: "Primaite",
|
||||
combine_service_traffic: bool = False,
|
||||
quantisation_levels: int = 5,
|
||||
):
|
||||
@@ -234,8 +239,8 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
def update_obs(self):
|
||||
"""todo: complete description."""
|
||||
def update(self):
|
||||
"""TODO: complete description."""
|
||||
obs = []
|
||||
for _, link in self.env.links.items():
|
||||
bandwidth = link.bandwidth
|
||||
@@ -262,15 +267,14 @@ class LinkTrafficLevels(AbstractObservationComponent):
|
||||
class ObservationsHandler:
|
||||
"""Component-based observation space handler."""
|
||||
|
||||
class registry(Enum):
|
||||
"""todo: complete description."""
|
||||
|
||||
NODE_LINK_TABLE: NodeLinkTable
|
||||
NODE_STATUSES: NodeStatuses
|
||||
LINK_TRAFFIC_LEVELS: LinkTrafficLevels
|
||||
registry = {
|
||||
"NODE_LINK_TABLE": NodeLinkTable,
|
||||
"NODE_STATUSES": NodeStatuses,
|
||||
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
"""todo: complete description."""
|
||||
"""TODO: complete description."""
|
||||
"""Initialise the handler without any components yet. They"""
|
||||
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||
self.space: spaces.Space
|
||||
@@ -279,33 +283,33 @@ class ObservationsHandler:
|
||||
# self.registry.LINK_TRAFFIC_LEVELS
|
||||
|
||||
def update_obs(self):
|
||||
"""todo: complete description."""
|
||||
"""TODO: complete description."""
|
||||
current_obs = []
|
||||
for obs in self.registered_obs_components:
|
||||
obs.update_obs()
|
||||
obs.update()
|
||||
current_obs.append(obs.current_observation)
|
||||
self.current_observation = tuple(current_obs)
|
||||
|
||||
def register(self, obs_component: AbstractObservationComponent):
|
||||
"""todo: complete description."""
|
||||
"""TODO: complete description."""
|
||||
self.registered_obs_components.append(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def deregister(self, obs_component: AbstractObservationComponent):
|
||||
"""todo: complete description."""
|
||||
"""TODO: complete description."""
|
||||
self.registered_obs_components.remove(obs_component)
|
||||
self.update_space()
|
||||
|
||||
def update_space(self):
|
||||
"""todo: complete description."""
|
||||
"""TODO: complete description."""
|
||||
component_spaces = []
|
||||
for obs_comp in self.registered_obs_components:
|
||||
component_spaces.append(obs_comp.space)
|
||||
self.space = spaces.Tuple(component_spaces)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, obs_space_config):
|
||||
"""todo: complete description.
|
||||
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||
"""TODO: complete description.
|
||||
|
||||
This method parses config items related to the observation space, then
|
||||
creates the necessary components and adds them to the observation handler.
|
||||
@@ -316,11 +320,13 @@ class ObservationsHandler:
|
||||
for component_cfg in obs_space_config["components"]:
|
||||
# Figure out which class can instantiate the desired component
|
||||
comp_type = component_cfg["name"]
|
||||
comp_class = cls.registry[comp_type].value
|
||||
comp_class = cls.registry[comp_type]
|
||||
|
||||
# Create the component with options from the YAML
|
||||
component = comp_class(**component_cfg["options"])
|
||||
options = component_cfg.get("options") or {}
|
||||
component = comp_class(env, **options)
|
||||
|
||||
handler.register(component)
|
||||
|
||||
handler.update_obs()
|
||||
return handler
|
||||
|
||||
@@ -149,7 +149,8 @@ class Primaite(Env):
|
||||
# The action type
|
||||
self.action_type = 0
|
||||
|
||||
# todo: proper description here
|
||||
# TODO: proper description here
|
||||
self.obs_config: dict
|
||||
self.obs_handler: ObservationsHandler
|
||||
|
||||
# Open the config file and build the environment laydown
|
||||
@@ -161,10 +162,6 @@ class Primaite(Env):
|
||||
_LOGGER.error("Could not load the environment configuration")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
|
||||
# If it doesn't exist after parsing config, create default obs space.
|
||||
if self.get("obs_handler") is None:
|
||||
self.configure_obs_space()
|
||||
|
||||
# Store the node objects as node attributes
|
||||
# (This is so we can access them as objects)
|
||||
for node in self.network:
|
||||
@@ -195,6 +192,10 @@ class Primaite(Env):
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
print("Could not save network diagram")
|
||||
|
||||
# # If it doesn't exist after parsing config, create default obs space.
|
||||
# if getattr(self, "obs_handler", None) is None:
|
||||
# self.configure_obs_space()
|
||||
|
||||
# Initiate observation space
|
||||
self.observation_space, self.env_obs = self.init_observations()
|
||||
|
||||
@@ -646,13 +647,22 @@ class Primaite(Env):
|
||||
pass
|
||||
|
||||
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
||||
"""todo: write docstring."""
|
||||
"""TODO: write docstring."""
|
||||
if getattr(self, "obs_config", None) is None:
|
||||
self.obs_config = {
|
||||
"components": [
|
||||
{"name": "NODE_LINK_TABLE"},
|
||||
]
|
||||
}
|
||||
|
||||
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
|
||||
|
||||
return self.obs_handler.space, self.obs_handler.current_observation
|
||||
|
||||
def update_environent_obs(self):
|
||||
"""Updates the observation space based on the node and link status.
|
||||
|
||||
todo: better docstring
|
||||
TODO: better docstring
|
||||
"""
|
||||
self.obs_handler.update_obs()
|
||||
self.env_obs = self.obs_handler.current_observation
|
||||
@@ -692,7 +702,7 @@ class Primaite(Env):
|
||||
self.get_action_info(item)
|
||||
elif item["itemType"] == "OBSERVATION_SPACE":
|
||||
# Get the observation information
|
||||
self.configure_obs_space(item)
|
||||
self.save_obs_config(item)
|
||||
elif item["itemType"] == "STEPS":
|
||||
# Get the steps information
|
||||
self.get_steps_info(item)
|
||||
@@ -1025,16 +1035,9 @@ class Primaite(Env):
|
||||
"""
|
||||
self.action_type = ActionType[action_info["type"]]
|
||||
|
||||
def configure_obs_space(self, observation_config: Optional[Dict] = None):
|
||||
"""todo: better docstring."""
|
||||
if observation_config is None:
|
||||
observation_config = {
|
||||
"components": [
|
||||
{"name": "NODE_LINK_TABLE"},
|
||||
]
|
||||
}
|
||||
|
||||
self.obs_handler = ObservationsHandler[observation_config]
|
||||
def save_obs_config(self, obs_config: Optional[Dict] = None):
|
||||
"""TODO: better docstring."""
|
||||
self.obs_config = obs_config
|
||||
|
||||
def get_steps_info(self, steps_info):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user