diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 5bad056c..c4402b69 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -33,18 +33,16 @@ class AbstractObservationComponent(ABC): @abstractmethod def update(self): - """Look at the environment and update the current observation value.""" + """Update the observation based on the current state of the environment.""" self.current_observation = NotImplemented class NodeLinkTable(AbstractObservationComponent): - """Table with nodes/links as rows and hardware/software status as cols. - - Initialise the observation space with the BOX option chosen. + """Table with nodes and links as rows and hardware/software status as cols. This will create the observation space formatted as a table of integers. There is one row per node, followed by one row per link. - Columns are as follows: + The number of columns is 4 plus one per service. They are: * node/link ID * node hardware status / 0 for links * node operating system status (if active/service) / 0 for links @@ -56,8 +54,6 @@ class NodeLinkTable(AbstractObservationComponent): For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be ``(12, 7)`` - #TODO: clean up description - """ _FIXED_PARAMETERS = 4 @@ -84,13 +80,9 @@ class NodeLinkTable(AbstractObservationComponent): self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE) def update(self): - """Update the observation. + """Update the observation based on current environment state. - Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_box_observations` - This function can only be called if the observation space setting is set to BOX. - TODO: complete description.. + The structure of the observation space is described in :class:`.NodeLinkTable` """ item_index = 0 nodes = self.env.nodes @@ -141,14 +133,30 @@ class NodeLinkTable(AbstractObservationComponent): class NodeStatuses(AbstractObservationComponent): - """TODO: complete description. + """Flat list of nodes' hardware, OS, file system, and service states. - This will create the observation space with node observations followed by link observations. - Each node has 3 elements in the observation space plus 1 per service, more specifically: - * hardware state - * operating system state - * file system state - * service states (one per service) + The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by + integers. + Each node has 3 elements plus 1 per service. It will have the following structure: + .. code-block:: + [ + node1 hardware state, + node1 OS state, + node1 file system state, + node1 service1 state, + node1 service2 state, + node1 serviceN state (one for each service), + node2 hardware state, + node2 OS state, + node2 file system state, + node2 service1 state, + node2 service2 state, + node2 serviceN state (one for each service), + ... + ] + + :param env: The environment that forms the basis of the observations + :type env: Primaite """ _DATA_TYPE = np.int64 @@ -172,14 +180,9 @@ class NodeStatuses(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update(self): - """TODO: complete description. - - Update the environment's observation state based on the current status of nodes and links. - - The structure of the observation space is described in :func:`~_init_multidiscrete_observations` - This function can only be called if the observation space setting is set to MULTIDISCRETE. - + """Update the observation based on current environment state. + The structure of the observation space is described in :class:`.NodeStatuses` """ obs = [] for _, node in self.env.nodes.items(): @@ -201,15 +204,28 @@ class NodeStatuses(AbstractObservationComponent): class LinkTrafficLevels(AbstractObservationComponent): - """TODO: complete description. + """Flat list of traffic levels encoded into banded categories. - Each link has one element in the observation space, corresponding to the traffic load, - it can take the following values: + For each link, total traffic or traffic per service is encoded into a categorical value. + For example, if ``quantisation_levels=5``, the traffic levels represent these values: 0 = No traffic (0% of bandwidth) 1 = No traffic (0%-33% of bandwidth) 2 = No traffic (33%-66% of bandwidth) 3 = No traffic (66%-100% of bandwidth) 4 = No traffic (100% of bandwidth) + + .. note:: + The lowest category always corresponds to no traffic and the highest category to the link being at max capacity. + Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories. + + :param env: The environment that forms the basis of the observations + :type env: Primaite + :param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually, + defaults to False + :type combine_service_traffic: bool, optional + :param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value , + defaults to 5 + :type quantisation_levels: int, optional """ _DATA_TYPE = np.int64 @@ -220,7 +236,10 @@ class LinkTrafficLevels(AbstractObservationComponent): combine_service_traffic: bool = False, quantisation_levels: int = 5, ): + assert quantisation_levels >= 3 + super().__init__(env) + self._combine_service_traffic: bool = combine_service_traffic self._quantisation_levels: int = quantisation_levels self._entries_per_link: int = 1 @@ -240,7 +259,10 @@ class LinkTrafficLevels(AbstractObservationComponent): self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) def update(self): - """TODO: complete description.""" + """Update the observation based on current environment state. + + The structure of the observation space is described in :class:`.LinkTrafficLevels` + """ obs = [] for _, link in self.env.links.items(): bandwidth = link.bandwidth @@ -265,7 +287,11 @@ class LinkTrafficLevels(AbstractObservationComponent): class ObservationsHandler: - """Component-based observation space handler.""" + """Component-based observation space handler. + + This allows users to configure observation spaces by mixing and matching components. + Each component can also define further parameters to make them more flexible. + """ registry = { "NODE_LINK_TABLE": NodeLinkTable, @@ -274,8 +300,6 @@ class ObservationsHandler: } def __init__(self): - """TODO: complete description.""" - """Initialise the handler without any components yet. They""" self.registered_obs_components: List[AbstractObservationComponent] = [] self.space: spaces.Space self.current_observation: Union[Tuple[np.ndarray], np.ndarray] @@ -283,7 +307,7 @@ class ObservationsHandler: # self.registry.LINK_TRAFFIC_LEVELS def update_obs(self): - """TODO: complete description.""" + """Fetch fresh information about the environment.""" current_obs = [] for obs in self.registered_obs_components: obs.update() @@ -296,17 +320,26 @@ class ObservationsHandler: self.current_observation = tuple(current_obs) def register(self, obs_component: AbstractObservationComponent): - """TODO: complete description.""" + """Add a component for this handler to track. + + :param obs_component: The component to add. + :type obs_component: AbstractObservationComponent + """ self.registered_obs_components.append(obs_component) self.update_space() def deregister(self, obs_component: AbstractObservationComponent): - """TODO: complete description.""" + """Remove a component from this handler. + + :param obs_component: Which component to remove. It must exist within this object's + ``registered_obs_components`` attribute. + :type obs_component: AbstractObservationComponent + """ self.registered_obs_components.remove(obs_component) self.update_space() def update_space(self): - """TODO: complete description.""" + """Rebuild the handler's composite observation space from its components.""" component_spaces = [] for obs_comp in self.registered_obs_components: component_spaces.append(obs_comp.space) @@ -319,10 +352,28 @@ class ObservationsHandler: @classmethod def from_config(cls, env: "Primaite", obs_space_config: dict): - """TODO: complete description. + """Parse a config dictinary, return a new observation handler populated with new observation component objects. - This method parses config items related to the observation space, then - creates the necessary components and adds them to the observation handler. + The expected format for the config dictionary is: + + ..code-block::python + config = { + components: [ + { + "name": "" + }, + { + "name": "" + "options": {"opt1": val1, "opt2": val2} + }, + { + ... + }, + ] + } + + :return: Observation handler + :rtype: primaite.environment.observations.ObservationsHandler """ # Instantiate the handler handler = cls() diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 81557075..8cff91d8 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -6,7 +6,7 @@ import csv import logging import os.path from datetime import datetime -from typing import Dict, Optional, Tuple +from typing import Dict, Tuple import networkx as nx import numpy as np @@ -643,16 +643,17 @@ class Primaite(Env): pass def init_observations(self) -> Tuple[spaces.Space, np.ndarray]: - """TODO: write docstring.""" + """Create the environment's observation handler. + + :return: The observation space, initial observation (zeroed out array with the correct shape) + :rtype: Tuple[spaces.Space, np.ndarray] + """ self.obs_handler = ObservationsHandler.from_config(self, self.obs_config) return self.obs_handler.space, self.obs_handler.current_observation def update_environent_obs(self): - """Updates the observation space based on the node and link status. - - TODO: better docstring - """ + """Updates the observation space based on the node and link status.""" self.obs_handler.update_obs() self.env_obs = self.obs_handler.current_observation @@ -1024,8 +1025,16 @@ class Primaite(Env): """ self.action_type = ActionType[action_info["type"]] - def save_obs_config(self, obs_config: Optional[Dict] = None): - """TODO: better docstring.""" + def save_obs_config(self, obs_config: dict): + """Cache the config for the observation space. + + This is necessary as the observation space can't be built while reading the config, + it must be done after all the nodes, links, and services have been initialised. + + :param obs_config: Parsed config relating to the observation space. The format is described in + :py:meth:`primaite.environment.observations.ObservationsHandler.from_config` + :type obs_config: dict + """ self.obs_config = obs_config def get_steps_info(self, steps_info):