Add docstrings to new observation code

This commit is contained in:
Marek Wolan
2023-06-01 21:28:38 +01:00
parent d473794963
commit 084112a2e4
2 changed files with 109 additions and 49 deletions

View File

@@ -33,18 +33,16 @@ class AbstractObservationComponent(ABC):
@abstractmethod
def update(self):
"""Look at the environment and update the current observation value."""
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes/links as rows and hardware/software status as cols.
Initialise the observation space with the BOX option chosen.
"""Table with nodes and links as rows and hardware/software status as cols.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
Columns are as follows:
The number of columns is 4 plus one per service. They are:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
@@ -56,8 +54,6 @@ class NodeLinkTable(AbstractObservationComponent):
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
#TODO: clean up description
"""
_FIXED_PARAMETERS = 4
@@ -84,13 +80,9 @@ class NodeLinkTable(AbstractObservationComponent):
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
def update(self):
"""Update the observation.
"""Update the observation based on current environment state.
Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_box_observations`
This function can only be called if the observation space setting is set to BOX.
TODO: complete description..
The structure of the observation space is described in :class:`.NodeLinkTable`
"""
item_index = 0
nodes = self.env.nodes
@@ -141,14 +133,30 @@ class NodeLinkTable(AbstractObservationComponent):
class NodeStatuses(AbstractObservationComponent):
"""TODO: complete description.
"""Flat list of nodes' hardware, OS, file system, and service states.
This will create the observation space with node observations followed by link observations.
Each node has 3 elements in the observation space plus 1 per service, more specifically:
* hardware state
* operating system state
* file system state
* service states (one per service)
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
Each node has 3 elements plus 1 per service. It will have the following structure:
.. code-block::
[
node1 hardware state,
node1 OS state,
node1 file system state,
node1 service1 state,
node1 service2 state,
node1 serviceN state (one for each service),
node2 hardware state,
node2 OS state,
node2 file system state,
node2 service1 state,
node2 service2 state,
node2 serviceN state (one for each service),
...
]
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
_DATA_TYPE = np.int64
@@ -172,14 +180,9 @@ class NodeStatuses(AbstractObservationComponent):
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""TODO: complete description.
Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_multidiscrete_observations`
This function can only be called if the observation space setting is set to MULTIDISCRETE.
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeStatuses`
"""
obs = []
for _, node in self.env.nodes.items():
@@ -201,15 +204,28 @@ class NodeStatuses(AbstractObservationComponent):
class LinkTrafficLevels(AbstractObservationComponent):
"""TODO: complete description.
"""Flat list of traffic levels encoded into banded categories.
Each link has one element in the observation space, corresponding to the traffic load,
it can take the following values:
For each link, total traffic or traffic per service is encoded into a categorical value.
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
.. note::
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value ,
defaults to 5
:type quantisation_levels: int, optional
"""
_DATA_TYPE = np.int64
@@ -220,7 +236,10 @@ class LinkTrafficLevels(AbstractObservationComponent):
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
assert quantisation_levels >= 3
super().__init__(env)
self._combine_service_traffic: bool = combine_service_traffic
self._quantisation_levels: int = quantisation_levels
self._entries_per_link: int = 1
@@ -240,7 +259,10 @@ class LinkTrafficLevels(AbstractObservationComponent):
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""TODO: complete description."""
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.LinkTrafficLevels`
"""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
@@ -265,7 +287,11 @@ class LinkTrafficLevels(AbstractObservationComponent):
class ObservationsHandler:
"""Component-based observation space handler."""
"""Component-based observation space handler.
This allows users to configure observation spaces by mixing and matching components.
Each component can also define further parameters to make them more flexible.
"""
registry = {
"NODE_LINK_TABLE": NodeLinkTable,
@@ -274,8 +300,6 @@ class ObservationsHandler:
}
def __init__(self):
"""TODO: complete description."""
"""Initialise the handler without any components yet. They"""
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
@@ -283,7 +307,7 @@ class ObservationsHandler:
# self.registry.LINK_TRAFFIC_LEVELS
def update_obs(self):
"""TODO: complete description."""
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
obs.update()
@@ -296,17 +320,26 @@ class ObservationsHandler:
self.current_observation = tuple(current_obs)
def register(self, obs_component: AbstractObservationComponent):
"""TODO: complete description."""
"""Add a component for this handler to track.
:param obs_component: The component to add.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
"""TODO: complete description."""
"""Remove a component from this handler.
:param obs_component: Which component to remove. It must exist within this object's
``registered_obs_components`` attribute.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
"""TODO: complete description."""
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
@@ -319,10 +352,28 @@ class ObservationsHandler:
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
"""TODO: complete description.
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
This method parses config items related to the observation space, then
creates the necessary components and adds them to the observation handler.
The expected format for the config dictionary is:
..code-block::python
config = {
components: [
{
"name": "<COMPONENT1_NAME>"
},
{
"name": "<COMPONENT2_NAME>"
"options": {"opt1": val1, "opt2": val2}
},
{
...
},
]
}
:return: Observation handler
:rtype: primaite.environment.observations.ObservationsHandler
"""
# Instantiate the handler
handler = cls()

View File

@@ -6,7 +6,7 @@ import csv
import logging
import os.path
from datetime import datetime
from typing import Dict, Optional, Tuple
from typing import Dict, Tuple
import networkx as nx
import numpy as np
@@ -643,16 +643,17 @@ class Primaite(Env):
pass
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""TODO: write docstring."""
"""Create the environment's observation handler.
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
"""Updates the observation space based on the node and link status.
TODO: better docstring
"""
"""Updates the observation space based on the node and link status."""
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
@@ -1024,8 +1025,16 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def save_obs_config(self, obs_config: Optional[Dict] = None):
"""TODO: better docstring."""
def save_obs_config(self, obs_config: dict):
"""Cache the config for the observation space.
This is necessary as the observation space can't be built while reading the config,
it must be done after all the nodes, links, and services have been initialised.
:param obs_config: Parsed config relating to the observation space. The format is described in
:py:meth:`primaite.environment.observations.ObservationsHandler.from_config`
:type obs_config: dict
"""
self.obs_config = obs_config
def get_steps_info(self, steps_info):