Merge remote-tracking branch 'origin/dev' into feature/915_PRI-31_Packaging_Deployment

# Conflicts:
#	docs/source/about.rst
#	docs/source/config.rst
#	src/primaite/common/config_values_main.py
#	src/primaite/environment/primaite_env.py
#	src/primaite/main.py
#	tests/config/multidiscrete_obs_space_laydown_config.yaml
#	tests/config/obs_tests/laydown.yaml
#	tests/conftest.py
#	tests/test_observation_space.py
This commit is contained in:
Chris McCarthy
2023-06-09 13:41:05 +01:00
13 changed files with 1221 additions and 371 deletions

View File

@@ -0,0 +1,91 @@
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
"""The config class."""
class ConfigValuesMain(object):
"""Class to hold main config values."""
def __init__(self):
"""Init."""
# Generic
self.agent_identifier = "" # the agent in use
self.observation_config = None # observation space config
self.num_episodes = 0 # number of episodes to train over
self.num_steps = 0 # number of steps in an episode
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
self.config_filename_use_case = "" # the filename for the Use Case config file
self.session_type = "" # the session type to run (TRAINING or EVALUATION)
# Environment
self.observation_space_high_value = (
0 # The high value for the observation space
)
# Reward values
# Generic
self.all_ok = 0
# Node Hardware State
self.off_should_be_on = 0
self.off_should_be_resetting = 0
self.on_should_be_off = 0
self.on_should_be_resetting = 0
self.resetting_should_be_on = 0
self.resetting_should_be_off = 0
self.resetting = 0
# Node Software or Service State
self.good_should_be_patching = 0
self.good_should_be_compromised = 0
self.good_should_be_overwhelmed = 0
self.patching_should_be_good = 0
self.patching_should_be_compromised = 0
self.patching_should_be_overwhelmed = 0
self.patching = 0
self.compromised_should_be_good = 0
self.compromised_should_be_patching = 0
self.compromised_should_be_overwhelmed = 0
self.compromised = 0
self.overwhelmed_should_be_good = 0
self.overwhelmed_should_be_patching = 0
self.overwhelmed_should_be_compromised = 0
self.overwhelmed = 0
# Node File System State
self.good_should_be_repairing = 0
self.good_should_be_restoring = 0
self.good_should_be_corrupt = 0
self.good_should_be_destroyed = 0
self.repairing_should_be_good = 0
self.repairing_should_be_restoring = 0
self.repairing_should_be_corrupt = 0
self.repairing_should_be_destroyed = (
0 # Repairing does not fix destroyed state - you need to restore
)
self.repairing = 0
self.restoring_should_be_good = 0
self.restoring_should_be_repairing = 0
self.restoring_should_be_corrupt = (
0 # Not the optimal method (as repair will fix corruption)
)
self.restoring_should_be_destroyed = 0
self.restoring = 0
self.corrupt_should_be_good = 0
self.corrupt_should_be_repairing = 0
self.corrupt_should_be_restoring = 0
self.corrupt_should_be_destroyed = 0
self.corrupt = 0
self.destroyed_should_be_good = 0
self.destroyed_should_be_repairing = 0
self.destroyed_should_be_restoring = 0
self.destroyed_should_be_corrupt = 0
self.destroyed = 0
self.scanning = 0
# IER status
self.red_ier_running = 0
self.green_ier_blocked = 0
# Patching / Reset
self.os_patching_duration = 0 # The time taken to patch the OS
self.node_reset_duration = 0 # The time taken to reset a node (hardware)
self.service_patching_duration = 0 # The time taken to patch a service
self.file_system_repairing_limit = 0 # The time take to repair a file
self.file_system_restoring_limit = 0 # The time take to restore a file
self.file_system_scanning_limit = 0 # The time taken to scan the file system

View File

@@ -0,0 +1,403 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
import numpy as np
from gym import spaces
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# This dependency is only needed for type hints,
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite"):
_LOGGER.info(f"Initialising {self} observation component")
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
return NotImplemented
@abstractmethod
def update(self):
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes and links as rows and hardware/software status as cols.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
The number of columns is 4 plus one per service. They are:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
"""
_FIXED_PARAMETERS: int = 4
_MAX_VAL: int = 1_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
num_items = self.env.num_links + self.env.num_nodes
num_columns = self.env.num_services + self._FIXED_PARAMETERS
observation_shape = (num_items, num_columns)
# 2. Create Observation space
self.space = spaces.Box(
low=0,
high=self._MAX_VAL,
shape=observation_shape,
dtype=self._DATA_TYPE,
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeLinkTable`
"""
item_index = 0
nodes = self.env.nodes
links = self.env.links
# Do nodes first
for _, node in nodes.items():
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][
service_index
] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.env.services_list:
self.current_observation[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for _, link in links.items():
self.current_observation[item_index][0] = int(link.get_id())
self.current_observation[item_index][1] = 0
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][
protocol_index + 4
] = protocol.get_load()
protocol_index += 1
item_index += 1
class NodeStatuses(AbstractObservationComponent):
"""Flat list of nodes' hardware, OS, file system, and service states.
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
Each node has 3 elements plus 1 per service. It will have the following structure:
.. code-block::
[
node1 hardware state,
node1 OS state,
node1 file system state,
node1 service1 state,
node1 service2 state,
node1 serviceN state (one for each service),
node2 hardware state,
node2 OS state,
node2 file system state,
node2 service1 state,
node2 service2 state,
node2 serviceN state (one for each service),
...
]
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
node_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
services_shape = [len(SoftwareState) + 1] * self.env.num_services
node_shape = node_shape + services_shape
shape = node_shape * self.env.num_nodes
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeStatuses`
"""
obs = []
for _, node in self.env.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
service_states = [0] * self.env.num_services
if isinstance(node, ActiveNode):
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend(
[hardware_state, software_state, file_system_state, *service_states]
)
self.current_observation[:] = obs
class LinkTrafficLevels(AbstractObservationComponent):
"""Flat list of traffic levels encoded into banded categories.
For each link, total traffic or traffic per service is encoded into a categorical value.
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
.. note::
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value ,
defaults to 5
:type quantisation_levels: int, optional
"""
_DATA_TYPE: type = np.int64
def __init__(
self,
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
if quantisation_levels < 3:
_msg = (
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
f"Resetting to default value (5)"
)
_LOGGER.warning(_msg)
quantisation_levels = 5
super().__init__(env)
self._combine_service_traffic: bool = combine_service_traffic
self._quantisation_levels: int = quantisation_levels
self._entries_per_link: int = 1
if not self._combine_service_traffic:
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
)
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.LinkTrafficLevels`
"""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [protocol.get_load() for protocol in link.protocol_list]
for load in loads:
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (
1 / (self._quantisation_levels - 2)
) + 1
obs.append(int(traffic_level))
self.current_observation[:] = obs
class ObservationsHandler:
"""Component-based observation space handler.
This allows users to configure observation spaces by mixing and matching components.
Each component can also define further parameters to make them more flexible.
"""
_REGISTRY: Final[Dict[str, type]] = {
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
}
def __init__(self):
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
def update_obs(self):
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
obs.update()
current_obs.append(obs.current_observation)
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
else:
self.current_observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
:param obs_component: The component to add.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
"""Remove a component from this handler.
:param obs_component: Which component to remove. It must exist within this object's
``registered_obs_components`` attribute.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# If there is only one component, don't use a tuple space, just pass through that component's space.
if len(component_spaces) == 1:
self.space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
The expected format for the config dictionary is:
..code-block::python
config = {
components: [
{
"name": "<COMPONENT1_NAME>"
},
{
"name": "<COMPONENT2_NAME>"
"options": {"opt1": val1, "opt2": val2}
},
{
...
},
]
}
:return: Observation handler
:rtype: primaite.environment.observations.ObservationsHandler
"""
# Instantiate the handler
handler = cls()
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
comp_class = cls._REGISTRY[comp_type]
# Create the component with options from the YAML
options = component_cfg.get("options") or {}
component = comp_class(env, **options)
handler.register(component)
handler.update_obs()
return handler

View File

@@ -24,11 +24,11 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState,
)
from primaite.common.service import Service
from primaite.environment.observations import ObservationsHandler
from primaite.config import training_config
from primaite.config.training_config import TrainingConfig
from primaite.environment.reward import calculate_reward_function
@@ -51,14 +51,11 @@ _LOGGER.setLevel(logging.INFO)
class Primaite(Env):
"""PRIMmary AI Training Evironment (Primaite) class."""
# Observation / Action Space contants
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
ACTION_SPACE_NODE_ACTION_VALUES = 4
ACTION_SPACE_ACL_ACTION_VALUES = 3
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
# Action Space contants
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(
self,
@@ -165,8 +162,18 @@ class Primaite(Env):
# Number of ports - gets a value when config is loaded
self.num_ports = 0
# Observation type, by default box.
self.observation_type = ObservationType.BOX
# The action type
self.action_type = 0
# TODO fix up with TrainingConfig
# stores the observation config from the yaml, default is NODE_LINK_TABLE
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
if self.config_values.observation_config is not None:
self.obs_config = self.config_values.observation_config
# Observation Handler manages the user-configurable observation space.
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
@@ -229,7 +236,7 @@ class Primaite(Env):
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info(f"Invalid action type selected")
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
# Set up a csv to store the results of the training
try:
header = ["Episode", "Average Reward"]
@@ -424,9 +431,7 @@ class Primaite(Env):
_action: The action space from the agent
"""
# At the moment, actions are only affecting nodes
print("")
print(_action)
print(self.action_dict)
if self.training_config.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
elif self.training_config.action_type == ActionType.ACL:
@@ -652,252 +657,20 @@ class Primaite(Env):
else:
pass
def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the BOX option chosen.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
Columns are as follows:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
:return: Box gym observation
:rtype: gym.spaces.Box
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space type BOX selected")
# 1. Determine observation shape from laydown
num_items = self.num_links + self.num_nodes
num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
observation_shape = (num_items, num_observation_parameters)
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.Box(
low=0,
high=self.OBSERVATION_SPACE_HIGH_VALUE,
shape=observation_shape,
dtype=np.int64,
)
initial_observation = np.zeros(observation_shape, dtype=np.int64)
return observation_space, initial_observation
def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the MULTIDISCRETE option chosen.
This will create the observation space with node observations followed by link observations.
Each node has 3 elements in the observation space plus 1 per service, more specifically:
* hardware state
* operating system state
* file system state
* service states (one per service)
Each link has one element in the observation space, corresponding to the traffic load,
it can take the following values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(37,)``
:return: MultiDiscrete gym observation
:rtype: gym.spaces.MultiDiscrete
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space MULTIDISCRETE selected")
# 1. Determine observation shape from laydown
node_obs_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
node_services = [len(SoftwareState) + 1] * self.num_services
node_obs_shape = node_obs_shape + node_services
# the magic number 5 refers to 5 states of quantisation of traffic amount.
# (zero, low, medium, high, fully utilised/overwhelmed)
link_obs_shape = [5] * self.num_links
observation_shape = node_obs_shape * self.num_nodes + link_obs_shape
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.MultiDiscrete(observation_shape)
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
return observation_space, initial_observation
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Build the observation space based on network laydown and provide initial obs.
"""Create the environment's observation handler.
This method uses the object's `num_links`, `num_nodes`, `num_services`,
`OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type`
attributes to figure out the correct shape and format for the observation space.
:raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType`
:return: Gym observation space
:rtype: gym.spaces.Space
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
if self.observation_type == ObservationType.BOX:
observation_space, initial_observation = self._init_box_observations()
return observation_space, initial_observation
elif self.observation_type == ObservationType.MULTIDISCRETE:
(
observation_space,
initial_observation,
) = self._init_multidiscrete_observations()
return observation_space, initial_observation
else:
errmsg = (
f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
f", got {self.observation_type} instead"
)
_LOGGER.error(errmsg)
raise ValueError(errmsg)
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
def _update_env_obs_box(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_box_observations`
This function can only be called if the observation space setting is set to BOX.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.BOX
item_index = 0
# Do nodes first
for node_key, node in self.nodes.items():
self.env_obs[item_index][0] = int(node.node_id)
self.env_obs[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.env_obs[item_index][2] = node.software_state.value
self.env_obs[item_index][3] = node.file_system_state_observed.value
else:
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.services_list:
if node.has_service(service):
self.env_obs[item_index][
service_index
] = node.get_service_state(service).value
else:
self.env_obs[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.services_list:
self.env_obs[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for link_key, link in self.links.items():
self.env_obs[item_index][0] = int(link.get_id())
self.env_obs[item_index][1] = 0
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
def _update_env_obs_multidiscrete(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_multidiscrete_observations`
This function can only be called if the observation space setting is set to MULTIDISCRETE.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.MULTIDISCRETE
obs = []
# 1. Set nodes
# Each node has the following variables in the observation space:
# - Hardware state
# - Software state
# - File System state
# - Service 1 state
# - Service 2 state
# - ...
# - Service N state
for node_key, node in self.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
services_states = [0] * self.num_services
if isinstance(
node, ActiveNode
): # ServiceNode is a subclass of ActiveNode so no need to check that also
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.services_list):
if node.has_service(service):
services_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
software_state,
file_system_state,
*services_states,
]
)
# 2. Set links
# Each link has just one variable in the observation space, it represents the traffic amount
# In order for the space to be fully MultiDiscrete, the amount of
# traffic on each link is quantised into a few levels:
# 0: no traffic (0% of bandwidth)
# 1: low traffic (0-33% of bandwidth)
# 2: medium traffic (33-66% of bandwidth)
# 3: high traffic (66-100% of bandwidth)
# 4: max traffic/overloaded (100% of bandwidth)
for link_key, link in self.links.items():
bandwidth = link.bandwidth
load = link.get_current_load()
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = 4
else:
traffic_level = (load / bandwidth) // (1 / 3) + 1
obs.append(int(traffic_level))
self.env_obs = np.asarray(obs)
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
"""Updates the observation space based on the node and link status."""
if self.observation_type == ObservationType.BOX:
self._update_env_obs_box()
elif self.observation_type == ObservationType.MULTIDISCRETE:
self._update_env_obs_multidiscrete()
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
def load_lay_down_config(self):
"""Loads config data in order to build the environment configuration."""
@@ -929,11 +702,9 @@ class Primaite(Env):
elif item["item_type"] == "PORTS":
# Create the list of ports
self.create_ports_list(item)
elif item["item_type"] == "OBSERVATIONS":
# Get the observation information
self.get_observation_info(item)
else:
# Do nothing (bad formatting)
item_type = item["item_type"]
_LOGGER.error(f"Invalid item_type: {item_type}")
pass
_LOGGER.info("Environment configuration loaded")
@@ -1260,6 +1031,28 @@ class Primaite(Env):
"""
self.observation_type = ObservationType[observation_info["type"]]
def get_action_info(self, action_info):
"""
Extracts action_info.
Args:
item: A config data item representing action info
"""
self.action_type = ActionType[action_info["type"]]
def save_obs_config(self, obs_config: dict):
"""Cache the config for the observation space.
This is necessary as the observation space can't be built while reading the config,
it must be done after all the nodes, links, and services have been initialised.
:param obs_config: Parsed config relating to the observation space. The format is described in
:py:meth:`primaite.environment.observations.ObservationsHandler.from_config`
:type obs_config: dict
"""
self.obs_config = obs_config
def reset_environment(self):
"""
# Resets environment.