Merge branch 'feature/1463-multidiscrete-observation-option' into feature/1468-observations-class

This commit is contained in:
Marek Wolan
2023-06-01 11:09:21 +01:00
5 changed files with 353 additions and 178 deletions

View File

@@ -0,0 +1,12 @@
## Summary
*Replace this text with an explanation of what the changes are and how you implemented them. Can this impact any other parts of the codebase that we should keep in mind?*
## Test process
*How have you tested this (if applicable)?*
## Checklist
- [ ] This PR is linked to a **work item**
- [ ] I have performed **self-review** of the code
- [ ] I have written **tests** for any new functionality added with this PR
- [ ] I have updated the **documentation** if this PR changes or adds functionality
- [ ] I have run **pre-commit** checks for code style

View File

@@ -6,7 +6,7 @@ import csv
import logging
import os.path
from datetime import datetime
from typing import Dict
from typing import Dict, Tuple
import networkx as nx
import numpy as np
@@ -641,172 +641,252 @@ class Primaite(Env):
else:
pass
def init_observations(self):
def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the BOX option chosen.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
Columns are as follows:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
:return: Box gym observation
:rtype: gym.spaces.Box
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space type BOX selected")
# 1. Determine observation shape from laydown
num_items = self.num_links + self.num_nodes
num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
observation_shape = (num_items, num_observation_parameters)
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.Box(
low=0,
high=self.OBSERVATION_SPACE_HIGH_VALUE,
shape=observation_shape,
dtype=np.int64,
)
initial_observation = np.zeros(observation_shape, dtype=np.int64)
return observation_space, initial_observation
def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the MULTIDISCRETE option chosen.
This will create the observation space with node observations followed by link observations.
Each node has 3 elements in the observation space plus 1 per service, more specifically:
* hardware state
* operating system state
* file system state
* service states (one per service)
Each link has one element in the observation space, corresponding to the traffic load,
it can take the following values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(37,)``
:return: MultiDiscrete gym observation
:rtype: gym.spaces.MultiDiscrete
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space MULTIDISCRETE selected")
# 1. Determine observation shape from laydown
node_obs_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
node_services = [len(SoftwareState) + 1] * self.num_services
node_obs_shape = node_obs_shape + node_services
# the magic number 5 refers to 5 states of quantisation of traffic amount.
# (zero, low, medium, high, fully utilised/overwhelmed)
link_obs_shape = [5] * self.num_links
observation_shape = node_obs_shape * self.num_nodes + link_obs_shape
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.MultiDiscrete(observation_shape)
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
return observation_space, initial_observation
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Build the observation space based on network laydown and provide initial obs.
This method uses the object's `num_links`, `num_nodes`, `num_services`,
`OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type`
attributes to figure out the correct shape and format for the observation space.
Returns
-------
gym.spaces.Space
Gym observation space
numpy.Array
Initial observation with all entries set to 0
:raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType`
:return: Gym observation space
:rtype: gym.spaces.Space
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
if self.observation_type == ObservationType.BOX:
_LOGGER.info("Observation space type BOX selected")
# 1. Determine observation shape from laydown
num_items = self.num_links + self.num_nodes
num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
observation_shape = (num_items, num_observation_parameters)
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.Box(
low=0,
high=self.OBSERVATION_SPACE_HIGH_VALUE,
shape=observation_shape,
dtype=np.int64,
)
initial_observation = np.zeros(observation_shape, dtype=np.int64)
observation_space, initial_observation = self._init_box_observations()
return observation_space, initial_observation
elif self.observation_type == ObservationType.MULTIDISCRETE:
_LOGGER.info("Observation space MULTIDISCRETE selected")
# 1. Determine observation shape from laydown
node_obs_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
node_services = [len(SoftwareState) + 1] * self.num_services
node_obs_shape = node_obs_shape + node_services
# the magic number 5 refers to 5 states of quantisation of traffic amount.
# (zero, low, medium, high, fully utilised/overwhelmed)
link_obs_shape = [5] * self.num_links
observation_shape = node_obs_shape * self.num_nodes + link_obs_shape
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.MultiDiscrete(observation_shape)
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
(
observation_space,
initial_observation,
) = self._init_multidiscrete_observations()
return observation_space, initial_observation
else:
raise ValueError(
errmsg = (
f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
f", got {self.observation_type} instead"
)
_LOGGER.error(errmsg)
raise ValueError(errmsg)
return observation_space, initial_observation
def _update_env_obs_box(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_box_observations`
This function can only be called if the observation space setting is set to BOX.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.BOX
item_index = 0
# Do nodes first
for node_key, node in self.nodes.items():
self.env_obs[item_index][0] = int(node.node_id)
self.env_obs[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.env_obs[item_index][2] = node.software_state.value
self.env_obs[item_index][3] = node.file_system_state_observed.value
else:
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.services_list:
if node.has_service(service):
self.env_obs[item_index][
service_index
] = node.get_service_state(service).value
else:
self.env_obs[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.services_list:
self.env_obs[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for link_key, link in self.links.items():
self.env_obs[item_index][0] = int(link.get_id())
self.env_obs[item_index][1] = 0
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
def _update_env_obs_multidiscrete(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_multidiscrete_observations`
This function can only be called if the observation space setting is set to MULTIDISCRETE.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.MULTIDISCRETE
obs = []
# 1. Set nodes
# Each node has the following variables in the observation space:
# - Hardware state
# - Software state
# - File System state
# - Service 1 state
# - Service 2 state
# - ...
# - Service N state
for node_key, node in self.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
services_states = [0] * self.num_services
if isinstance(
node, ActiveNode
): # ServiceNode is a subclass of ActiveNode so no need to check that also
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.services_list):
if node.has_service(service):
services_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
software_state,
file_system_state,
*services_states,
]
)
# 2. Set links
# Each link has just one variable in the observation space, it represents the traffic amount
# In order for the space to be fully MultiDiscrete, the amount of
# traffic on each link is quantised into a few levels:
# 0: no traffic (0% of bandwidth)
# 1: low traffic (0-33% of bandwidth)
# 2: medium traffic (33-66% of bandwidth)
# 3: high traffic (66-100% of bandwidth)
# 4: max traffic/overloaded (100% of bandwidth)
for link_key, link in self.links.items():
bandwidth = link.bandwidth
load = link.get_current_load()
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = 4
else:
traffic_level = (load / bandwidth) // (1 / 3) + 1
obs.append(int(traffic_level))
self.env_obs = np.asarray(obs)
def update_environent_obs(self):
"""Updates the observation space based on the node and link status."""
if self.observation_type == ObservationType.BOX:
item_index = 0
# Do nodes first
for node_key, node in self.nodes.items():
self.env_obs[item_index][0] = int(node.node_id)
self.env_obs[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.env_obs[item_index][2] = node.software_state.value
self.env_obs[item_index][3] = node.file_system_state_observed.value
else:
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.services_list:
if node.has_service(service):
self.env_obs[item_index][
service_index
] = node.get_service_state(service).value
else:
self.env_obs[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.services_list:
self.env_obs[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for link_key, link in self.links.items():
self.env_obs[item_index][0] = int(link.get_id())
self.env_obs[item_index][1] = 0
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
self._update_env_obs_box()
elif self.observation_type == ObservationType.MULTIDISCRETE:
obs = []
# 1. Set nodes
# Each node has the following variables in the observation space:
# - Hardware state
# - Software state
# - File System state
# - Service 1 state
# - Service 2 state
# - ...
# - Service N state
for node_key, node in self.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
services_states = [0] * self.num_services
if isinstance(
node, ActiveNode
): # ServiceNode is a subclass of ActiveNode so no need to check that also
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.services_list):
if node.has_service(service):
services_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
software_state,
file_system_state,
*services_states,
]
)
# 2. Set links
# Each link has just one variable in the observation space, it represents the traffic amount
# In order for the space to be fully MultiDiscrete, the amount of
# traffic on each link is quantised into a few levels:
# 0: no traffic (0% of bandwidth)
# 1: low traffic (0-33% of bandwidth)
# 2: medium traffic (33-66% of bandwidth)
# 3: high traffic (66-100% of bandwidth)
# 4: max traffic/overloaded (100% of bandwidth)
for link_key, link in self.links.items():
bandwidth = link.bandwidth
load = link.get_current_load()
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = 4
else:
traffic_level = (load / bandwidth) // (1 / 3) + 1
obs.append(int(traffic_level))
self.env_obs = np.asarray(obs)
self._update_env_obs_multidiscrete()
def load_config(self):
"""Loads config data in order to build the environment configuration."""
@@ -1179,10 +1259,8 @@ class Primaite(Env):
def get_observation_info(self, observation_info):
"""Extracts observation_info.
Parameters
----------
observation_info : str
Config item that defines which type of observation space to use
:param observation_info: Config item that defines which type of observation space to use
:type observation_info: str
"""
self.observation_type = ObservationType[observation_info["type"]]

View File

@@ -6,26 +6,63 @@
steps: 5
- itemType: PORTS
portsList:
- port: '21'
- port: '80'
- itemType: SERVICES
serviceList:
- name: ftp
- name: TCP
########################################
# Nodes
- itemType: NODE
node_id: '1'
name: node
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P1
priority: P5
hardware_state: 'ON'
ip_address: 192.168.0.1
ip_address: 192.168.1.1
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
state: GOOD
- itemType: POSITION
positions:
- node: '1'
x_pos: 309
y_pos: 78
- name: TCP
port: '80'
state: GOOD
- itemType: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- itemType: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- itemType: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- itemType: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'

View File

@@ -6,26 +6,63 @@
steps: 5
- itemType: PORTS
portsList:
- port: '21'
- port: '80'
- itemType: SERVICES
serviceList:
- name: ftp
- name: TCP
########################################
# Nodes
- itemType: NODE
node_id: '1'
name: node
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P1
priority: P5
hardware_state: 'ON'
ip_address: 192.168.0.1
ip_address: 192.168.1.1
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
state: GOOD
- itemType: POSITION
positions:
- node: '1'
x_pos: 309
y_pos: 78
- name: TCP
port: '80'
state: GOOD
- itemType: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- itemType: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- itemType: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- itemType: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'

View File

@@ -12,6 +12,12 @@ def test_creating_env_with_box_obs():
)
env.update_environent_obs()
# we have three nodes and two links, with one service
# therefore the box observation space will have:
# * 5 columns (four fixed and one for the service)
# * 5 rows (3 nodes + 2 links)
assert env.env_obs.shape == (5, 5)
def test_creating_env_with_multidiscrete_obs():
"""Try creating env with MultiDiscrete observation space."""
@@ -21,3 +27,8 @@ def test_creating_env_with_multidiscrete_obs():
/ "multidiscrete_obs_space_laydown_config.yaml",
)
env.update_environent_obs()
# we have three nodes and two links, with one service
# the nodes have hardware, OS, FS, and service, the links just have bandwidth,
# therefore we need 3*4 + 2 observations
assert env.env_obs.shape == (3 * 4 + 2,)