Integrate observation handler with components

This commit is contained in:
Marek Wolan
2023-06-01 13:28:40 +01:00
parent fe32d83782
commit d351e575ae

View File

@@ -1,227 +1,256 @@
# """Module for handling configurable observation spaces in PrimAITE."""
# import logging
# from abc import ABC, abstractmethod
# from enum import Enum
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from enum import Enum
from typing import List, Tuple
# import numpy as np
# from gym import spaces
import numpy as np
from gym import spaces
# from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
# from primaite.environment.primaite_env import Primaite
# from primaite.nodes.active_node import ActiveNode
# from primaite.nodes.service_node import ServiceNode
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.environment.primaite_env import Primaite
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# _LOGGER = logging.getLogger(__name__)
_LOGGER = logging.getLogger(__name__)
# class AbstractObservationComponent(ABC):
# """Represents a part of the PrimAITE observation space."""
# @abstractmethod
# def __init__(self, env: Primaite):
# _LOGGER.info(f"Initialising {self} observation component")
# self.env: Primaite = env
# self.space: spaces.Space
# self.current_observation: np.ndarray # type might be too restrictive?
# return NotImplemented
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
# @abstractmethod
# def update(self):
# """Look at the environment and update the current observation value"""
# self.current_observation = NotImplemented
@abstractmethod
def __init__(self, env: Primaite):
_LOGGER.info(f"Initialising {self} observation component")
self.env: Primaite = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
return NotImplemented
# # @abstractmethod
# # def export(self):
# # return NotImplemented
@abstractmethod
def update(self):
"""Look at the environment and update the current observation value."""
self.current_observation = NotImplemented
# class NodeLinkTable(AbstractObservationComponent):
# """Table with nodes/links as rows and hardware/software status as cols.
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes/links as rows and hardware/software status as cols.
# #todo: write full description
#todo: write full description
# """
"""
# _FIXED_PARAMETERS = 4
# _MAX_VAL = 1_000_000
# _DATA_TYPE = np.int64
_FIXED_PARAMETERS = 4
_MAX_VAL = 1_000_000
_DATA_TYPE = np.int64
# def __init__(self, env: Primaite):
# super().__init__()
def __init__(self, env: Primaite):
super().__init__()
# # 1. Define the shape of your observation space component
# num_items = self.env.num_links + self.env.num_nodes
# num_columns = self.env.num_services + self._FIXED_PARAMETERS
# observation_shape = (num_items, num_columns)
# 1. Define the shape of your observation space component
num_items = self.env.num_links + self.env.num_nodes
num_columns = self.env.num_services + self._FIXED_PARAMETERS
observation_shape = (num_items, num_columns)
# # 2. Create Observation space
# self.space = spaces.Box(
# low=0,
# high=self._MAX_VAL,
# shape=observation_shape,
# dtype=self._DATA_TYPE,
# )
# 2. Create Observation space
self.space = spaces.Box(
low=0,
high=self._MAX_VAL,
shape=observation_shape,
dtype=self._DATA_TYPE,
)
# # 3. Initialise Observation with zeroes
# self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeroes(observation_shape, dtype=self._DATA_TYPE)
# def update_obs(self):
# item_index = 0
# nodes = self.env.nodes
# links = self.env.links
# # Do nodes first
# for _, node in nodes.items():
# self.current_observation[item_index][0] = int(node.node_id)
# self.current_observation[item_index][1] = node.hardware_state.value
# if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
# self.current_observation[item_index][2] = node.software_state.value
# self.current_observation[item_index][
# 3
# ] = node.file_system_state_observed.value
# else:
# self.current_observation[item_index][2] = 0
# self.current_observation[item_index][3] = 0
# service_index = 4
# if isinstance(node, ServiceNode):
# for service in self.env.services_list:
# if node.has_service(service):
# self.current_observation[item_index][
# service_index
# ] = node.get_service_state(service).value
# else:
# self.current_observation[item_index][service_index] = 0
# service_index += 1
# else:
# # Not a service node
# for service in self.env.services_list:
# self.current_observation[item_index][service_index] = 0
# service_index += 1
# item_index += 1
def update_obs(self):
"""Update the observation.
# # Now do links
# for _, link in links.items():
# self.current_observation[item_index][0] = int(link.get_id())
# self.current_observation[item_index][1] = 0
# self.current_observation[item_index][2] = 0
# self.current_observation[item_index][3] = 0
# protocol_list = link.get_protocol_list()
# protocol_index = 0
# for protocol in protocol_list:
# self.current_observation[item_index][
# protocol_index + 4
# ] = protocol.get_load()
# protocol_index += 1
# item_index += 1
todo: complete description..
"""
item_index = 0
nodes = self.env.nodes
links = self.env.links
# Do nodes first
for _, node in nodes.items():
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][
service_index
] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.env.services_list:
self.current_observation[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for _, link in links.items():
self.current_observation[item_index][0] = int(link.get_id())
self.current_observation[item_index][1] = 0
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][
protocol_index + 4
] = protocol.get_load()
protocol_index += 1
item_index += 1
# class NodeStatuses(AbstractObservationComponent):
# _DATA_TYPE = np.int64
class NodeStatuses(AbstractObservationComponent):
"""todo: complete description."""
# def __init__(self):
# super().__init__()
_DATA_TYPE = np.int64
# # 1. Define the shape of your observation space component
# shape = [
# len(HardwareState) + 1,
# len(SoftwareState) + 1,
# len(FileSystemState) + 1,
# ]
# services_shape = [len(SoftwareState) + 1] * self.env.num_services
# shape = shape + services_shape
def __init__(self):
super().__init__()
# # 2. Create Observation space
# self.space = spaces.MultiDiscrete(shape)
# 1. Define the shape of your observation space component
shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
services_shape = [len(SoftwareState) + 1] * self.env.num_services
shape = shape + services_shape
# # 3. Initialise observation with zeroes
# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# def update_obs(self):
# obs = []
# for _, node in self.env.nodes.items():
# hardware_state = node.hardware_state.value
# software_state = 0
# file_system_state = 0
# service_states = [0] * self.env.num_services
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
# if isinstance(node, ActiveNode):
# software_state = node.software_state.value
# file_system_state = node.file_system_state_observed.value
def update_obs(self):
"""todo: complete description."""
obs = []
for _, node in self.env.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
service_states = [0] * self.env.num_services
# if isinstance(node, ServiceNode):
# for i, service in enumerate(self.env.services_list):
# if node.has_service(service):
# service_states[i] = node.get_service_state(service).value
# obs.extend([hardware_state, software_state, file_system_state, *service_states])
# self.current_observation[:] = obs
if isinstance(node, ActiveNode):
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend([hardware_state, software_state, file_system_state, *service_states])
self.current_observation[:] = obs
# class LinkTrafficLevels(AbstractObservationComponent):
# _DATA_TYPE = np.int64
class LinkTrafficLevels(AbstractObservationComponent):
"""todo: complete description."""
# def __init__(
# self, combine_service_traffic: bool = False, quantisation_levels: int = 5
# ):
# super().__init__()
# self._combine_service_traffic: bool = combine_service_traffic
# self._quantisation_levels: int = quantisation_levels
# self._entries_per_link: int = 1
_DATA_TYPE = np.int64
# if not self._combine_service_traffic:
# self._entries_per_link = self.env.num_services
def __init__(
self, combine_service_traffic: bool = False, quantisation_levels: int = 5
):
super().__init__()
self._combine_service_traffic: bool = combine_service_traffic
self._quantisation_levels: int = quantisation_levels
self._entries_per_link: int = 1
# # 1. Define the shape of your observation space component
# shape = (
# [self._quantisation_levels] * self.env.num_links * self._entries_per_link
# )
if not self._combine_service_traffic:
self._entries_per_link = self.env.num_services
# # 2. Create Observation space
# self.space = spaces.MultiDiscrete(shape)
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
)
# # 3. Initialise observation with zeroes
# self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# def update_obs(self):
# obs = []
# for _, link in self.env.links.items():
# bandwidth = link.bandwidth
# if self._combine_service_traffic:
# loads = [link.get_current_load()]
# else:
# loads = [protocol.get_load() for protocol in link.protocol_list]
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
# for load in loads:
# if load <= 0:
# traffic_level = 0
# elif load >= bandwidth:
# traffic_level = self._quantisation_levels - 1
# else:
# traffic_level = (load / bandwidth) // (
# 1 / (self._quantisation_levels - 1)
# ) + 1
def update_obs(self):
"""todo: complete description."""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [protocol.get_load() for protocol in link.protocol_list]
# obs.append(int(traffic_level))
for load in loads:
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (
1 / (self._quantisation_levels - 1)
) + 1
# self.current_observation[:] = obs
obs.append(int(traffic_level))
self.current_observation[:] = obs
# class ObservationsHandler:
# class registry(Enum):
# NODE_LINK_TABLE: NodeLinkTable
# NODE_STATUSES: NodeStatuses
# LINK_TRAFFIC_LEVELS: LinkTrafficLevels
class ObservationsHandler:
"""todo: complete description."""
# def __init__(self):
# ...
# # i can access the registry items like this:
# # self.registry.LINK_TRAFFIC_LEVELS
class registry(Enum):
"""todo: complete description."""
# def update_obs(self):
# ...
NODE_LINK_TABLE: NodeLinkTable
NODE_STATUSES: NodeStatuses
LINK_TRAFFIC_LEVELS: LinkTrafficLevels
# def register(self):
# ...
def __init__(self):
"""todo: complete description."""
"""Initialise the handler without any components yet. They"""
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Tuple[np.ndarray]
# i can access the registry items like this:
# self.registry.LINK_TRAFFIC_LEVELS
# def deregister(self, observation: AbstractObservationComponent):
# ...
def update_obs(self):
"""todo: complete description."""
current_obs = []
for obs in self.registered_obs_components:
obs.update_obs()
current_obs.append(obs.current_observation)
self.current_observation = tuple(current_obs)
# def export(self):
# ...
def register(self, obs_component: AbstractObservationComponent):
"""todo: complete description."""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
"""todo: complete description."""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
"""todo: complete description."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
self.space = spaces.Tuple(component_spaces)