Merge branch 'dev' into feature/898-Fix-the-functionality-of-resetting-a-node

This commit is contained in:
Brian Kanyora
2023-06-02 14:56:31 +01:00
12 changed files with 554 additions and 130 deletions

View File

@@ -6,7 +6,7 @@ import csv
import logging
import os.path
from datetime import datetime
from typing import Dict
from typing import Dict, Tuple
import networkx as nx
import numpy as np
@@ -23,6 +23,7 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState,
)
@@ -148,6 +149,9 @@ class Primaite(Env):
# The action type
self.action_type = 0
# Observation type, by default box.
self.observation_type = ObservationType.BOX
# Open the config file and build the environment laydown
try:
self.config_file = open(self.config_values.config_filename_use_case, "r")
@@ -187,42 +191,8 @@ class Primaite(Env):
_LOGGER.error("Exception occured", exc_info=True)
print("Could not save network diagram")
# Define Observation Space
# x = number of nodes and links (i.e. items)
# y = number of parameters to be sent
# For each item, we send:
# - [For Nodes] | [For Links]
# - node ID | link ID
# - hardware state | N/A
# - Software State | N/A
# - file system state | N/A
# - service A state | service A loading
# - service B state | service B loading
# - service C state | service C loading
# - service D state | service D loading
# - service E state | service E loading
# - service F state | service F loading
# - service G state | service G loading
# Calculate the number of items that need to be included in the
# observation space
num_items = self.num_links + self.num_nodes
# Set the number of observation parameters, being # of services plus id,
# hardware state, file system state and SoftwareState (i.e. 4)
self.num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
# Define the observation shape
self.observation_shape = (num_items, self.num_observation_parameters)
self.observation_space = spaces.Box(
low=0,
high=self.config_values.observation_space_high_value,
shape=self.observation_shape,
dtype=np.int64,
)
# This is the observation that is sent back via the rest and step functions
self.env_obs = np.zeros(self.observation_shape, dtype=np.int64)
# Initiate observation space
self.observation_space, self.env_obs = self.init_observations()
# Define Action Space - depends on action space type (Node or ACL)
if self.action_type == ActionType.NODE:
@@ -671,8 +641,134 @@ class Primaite(Env):
else:
pass
def update_environent_obs(self):
"""Updates the observation space based on the node and link status."""
def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]:
    """Build the BOX-style observation space and a zeroed initial observation.

    The observation is a 2-D table of integers with one row per node
    followed by one row per link. The columns are:

        * node/link ID
        * node hardware status / 0 for links
        * node operating system status (active/service nodes) / 0 for links
        * node file system status (active/service nodes) / 0 for links
        * node service1 status / traffic load from that service for links
        * node service2 status / traffic load from that service for links
        * ...
        * node serviceN status / traffic load from that service for links

    The row count is ``num_nodes + num_links`` and the column count is
    ``OBSERVATION_SPACE_FIXED_PARAMETERS + num_services``. For example, an
    environment with 5 nodes, 7 links and 3 services yields the shape
    ``(12, 7)`` (assuming ``OBSERVATION_SPACE_FIXED_PARAMETERS`` is 4, i.e.
    the four leading columns listed above).

    :return: A ``(observation_space, initial_observation)`` pair: the
        ``gym.spaces.Box`` describing the table, and an ``np.int64`` array
        of the same shape with all entries set to 0.
    :rtype: Tuple[gym.spaces.Box, numpy.ndarray]
    """
    _LOGGER.info("Observation space type BOX selected")

    # Table dimensions are derived purely from the network laydown.
    row_count = self.num_links + self.num_nodes
    column_count = self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
    table_shape = (row_count, column_count)

    box_space = spaces.Box(
        low=0,
        high=self.OBSERVATION_SPACE_HIGH_VALUE,
        shape=table_shape,
        dtype=np.int64,
    )
    zeroed_table = np.zeros(table_shape, dtype=np.int64)
    return box_space, zeroed_table
def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]:
    """Build the MULTIDISCRETE observation space and a zeroed initial observation.

    The observation is a flat vector: all node observations first, then all
    link observations.

    Each node contributes 3 elements plus 1 per service:

        * hardware state
        * operating system state
        * file system state
        * service states (one per service)

    Each link contributes a single element, the quantised traffic load:

        0 = No traffic (0% of bandwidth)
        1 = Low traffic (0%-33% of bandwidth)
        2 = Medium traffic (33%-66% of bandwidth)
        3 = High traffic (66%-100% of bandwidth)
        4 = Max traffic / overloaded (100% of bandwidth)

    For example, an environment with 5 nodes, 7 links and 3 services yields
    an observation of shape ``(37,)`` (5 * (3 + 3) + 7).

    :return: A ``(observation_space, initial_observation)`` pair: the
        ``gym.spaces.MultiDiscrete`` space, and a 1-D ``np.int64`` array of
        matching length with all entries set to 0.
    :rtype: Tuple[gym.spaces.MultiDiscrete, numpy.ndarray]
    """
    _LOGGER.info("Observation space MULTIDISCRETE selected")

    # Number of discrete values per node element. Each is the enum size
    # plus one; presumably the extra slot is the 0 used when a state does
    # not apply to a node - confirm against the enum definitions.
    fixed_node_sizes = [
        len(HardwareState) + 1,
        len(SoftwareState) + 1,
        len(FileSystemState) + 1,
    ]
    service_sizes = [len(SoftwareState) + 1] * self.num_services
    single_node_sizes = fixed_node_sizes + service_sizes

    # 5 = number of traffic quantisation levels
    # (zero, low, medium, high, fully utilised/overwhelmed).
    all_link_sizes = [5] * self.num_links

    # Node blocks are repeated once per node, then the links are appended.
    nvec = single_node_sizes * self.num_nodes + all_link_sizes

    multidiscrete_space = spaces.MultiDiscrete(nvec)
    zeroed_vector = np.zeros(len(nvec), dtype=np.int64)
    return multidiscrete_space, zeroed_vector
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
    """Create the observation space and the initial (all-zero) observation.

    Dispatches on the environment's ``observation_type`` attribute to the
    matching private builder, which in turn uses the ``num_links``,
    ``num_nodes`` and ``num_services`` attributes (and, for BOX, the
    ``OBSERVATION_SPACE_FIXED_PARAMETERS`` and
    ``OBSERVATION_SPACE_HIGH_VALUE`` attributes) to size the space.

    :raises ValueError: If ``observation_type`` is neither
        ``ObservationType.BOX`` nor ``ObservationType.MULTIDISCRETE``.
    :return: A ``(observation_space, initial_observation)`` pair, where the
        initial observation has all entries set to 0.
    :rtype: Tuple[gym.spaces.Space, numpy.ndarray]
    """
    selected_type = self.observation_type

    if selected_type == ObservationType.BOX:
        return self._init_box_observations()

    if selected_type == ObservationType.MULTIDISCRETE:
        return self._init_multidiscrete_observations()

    # Unknown observation type: log it, then fail loudly.
    errmsg = (
        f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
        f", got {self.observation_type} instead"
    )
    _LOGGER.error(errmsg)
    raise ValueError(errmsg)
def _update_env_obs_box(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_box_observations`
This function can only be called if the observation space setting is set to BOX.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.BOX
item_index = 0
# Do nodes first
@@ -715,6 +811,83 @@ class Primaite(Env):
protocol_index += 1
item_index += 1
def _update_env_obs_multidiscrete(self):
    """Refresh ``self.env_obs`` from the current node and link status.

    The layout of the observation vector is described in
    :func:`~_init_multidiscrete_observations`: one group of elements per
    node (hardware, software, file system, then one per service), followed
    by one traffic-level element per link.

    This function can only be called if the observation space setting is
    set to MULTIDISCRETE.

    The resulting array is always ``np.int64`` so that it matches the dtype
    of the initial observation, regardless of platform default integer
    width and even when there are no nodes or links.

    :raises AssertionError: If this function is called when the environment
        has the incorrect ``observation_type``
    """
    assert self.observation_type == ObservationType.MULTIDISCRETE

    obs = []

    # 1. Set nodes
    # Each node has the following variables in the observation space:
    #   - Hardware state
    #   - Software state
    #   - File System state
    #   - Service 1 state
    #   - Service 2 state
    #   - ...
    #   - Service N state
    for node in self.nodes.values():
        hardware_state = node.hardware_state.value
        # States that do not apply to this kind of node stay at 0.
        software_state = 0
        file_system_state = 0
        services_states = [0] * self.num_services

        # ServiceNode is a subclass of ActiveNode, so the service check
        # only needs to happen for nodes that passed the ActiveNode check.
        if isinstance(node, ActiveNode):
            software_state = node.software_state.value
            # The *observed* file system state is reported here.
            file_system_state = node.file_system_state_observed.value

            if isinstance(node, ServiceNode):
                for i, service in enumerate(self.services_list):
                    if node.has_service(service):
                        services_states[i] = node.get_service_state(service).value

        obs.extend(
            [
                hardware_state,
                software_state,
                file_system_state,
                *services_states,
            ]
        )

    # 2. Set links
    # Each link has just one variable in the observation space, it
    # represents the traffic amount. In order for the space to be fully
    # MultiDiscrete, the amount of traffic on each link is quantised into
    # a few levels:
    #   0: no traffic (0% of bandwidth)
    #   1: low traffic (0-33% of bandwidth)
    #   2: medium traffic (33-66% of bandwidth)
    #   3: high traffic (66-100% of bandwidth)
    #   4: max traffic/overloaded (100% of bandwidth)
    for link in self.links.values():
        bandwidth = link.bandwidth
        load = link.get_current_load()

        if load <= 0:
            traffic_level = 0
        elif load >= bandwidth:
            traffic_level = 4
        else:
            # Here 0 < load < bandwidth, so bandwidth is non-zero and the
            # ratio lies strictly between 0 and 1, giving levels 1..3.
            traffic_level = int((load / bandwidth) // (1 / 3)) + 1

        obs.append(traffic_level)

    # Pin the dtype: without it the result depends on the platform's
    # default integer type (and becomes float64 for an empty list).
    self.env_obs = np.asarray(obs, dtype=np.int64)
def update_environent_obs(self):
    """Update ``self.env_obs`` based on the node and link status.

    Delegates to the updater that matches the environment's
    ``observation_type``. If the type is neither BOX nor MULTIDISCRETE,
    nothing is updated.
    """
    selected_type = self.observation_type

    if selected_type == ObservationType.BOX:
        self._update_env_obs_box()
        return

    if selected_type == ObservationType.MULTIDISCRETE:
        self._update_env_obs_multidiscrete()
def load_config(self):
"""Loads config data in order to build the environment configuration."""
for item in self.config_data:
@@ -748,6 +921,9 @@ class Primaite(Env):
elif item["itemType"] == "ACTIONS":
# Get the action information
self.get_action_info(item)
elif item["itemType"] == "OBSERVATIONS":
# Get the observation information
self.get_observation_info(item)
elif item["itemType"] == "STEPS":
# Get the steps information
self.get_steps_info(item)
@@ -1080,6 +1256,14 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def get_observation_info(self, observation_info):
    """Extract the observation space type from an OBSERVATIONS config item.

    The item's ``"type"`` entry is looked up by name in
    ``ObservationType`` and stored on ``self.observation_type``.

    :param observation_info: Config item that defines which type of
        observation space to use; it is indexed with the key ``"type"``
    :type observation_info: dict (presumably; any mapping with a ``"type"``
        key works)
    """
    type_name = observation_info["type"]
    self.observation_type = ObservationType[type_name]
def get_steps_info(self, steps_info):
"""
Extracts steps_info.

View File

@@ -93,7 +93,6 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
"""
score = 0
final_node_operating_state = final_node.hardware_state
initial_node_operating_state = initial_node.hardware_state
reference_node_operating_state = reference_node.hardware_state
if final_node_operating_state == reference_node_operating_state:
@@ -101,27 +100,27 @@ def score_node_operating_state(final_node, initial_node, reference_node, config_
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference (current) state of node (i.e. at every step)
if initial_node_operating_state == HardwareState.ON:
if reference_node_operating_state == HardwareState.OFF:
# Need to compare reference and final (current) state of node (i.e. at every step)
if reference_node_operating_state == HardwareState.ON:
if final_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_on
elif reference_node_operating_state == HardwareState.RESETTING:
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_on
else:
pass
elif initial_node_operating_state == HardwareState.OFF:
if reference_node_operating_state == HardwareState.ON:
elif reference_node_operating_state == HardwareState.OFF:
if final_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_off
elif reference_node_operating_state == HardwareState.RESETTING:
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting_should_be_off
else:
pass
elif initial_node_operating_state == HardwareState.RESETTING:
if reference_node_operating_state == HardwareState.ON:
elif reference_node_operating_state == HardwareState.RESETTING:
if final_node_operating_state == HardwareState.ON:
score += config_values.on_should_be_resetting
elif reference_node_operating_state == HardwareState.OFF:
elif final_node_operating_state == HardwareState.OFF:
score += config_values.off_should_be_resetting
elif reference_node_operating_state == HardwareState.RESETTING:
elif final_node_operating_state == HardwareState.RESETTING:
score += config_values.resetting
else:
pass
@@ -143,7 +142,6 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
"""
score = 0
final_node_os_state = final_node.software_state
initial_node_os_state = initial_node.software_state
reference_node_os_state = reference_node.software_state
if final_node_os_state == reference_node_os_state:
@@ -151,29 +149,29 @@ def score_node_os_state(final_node, initial_node, reference_node, config_values)
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference (current) state of node (i.e. at every step)
if initial_node_os_state == SoftwareState.GOOD:
if reference_node_os_state == SoftwareState.PATCHING:
# Need to compare reference and final (current) state of node (i.e. at every step)
if reference_node_os_state == SoftwareState.GOOD:
if final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif reference_node_os_state == SoftwareState.COMPROMISED:
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
else:
pass
elif initial_node_os_state == SoftwareState.PATCHING:
if reference_node_os_state == SoftwareState.GOOD:
elif reference_node_os_state == SoftwareState.PATCHING:
if final_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif reference_node_os_state == SoftwareState.COMPROMISED:
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif reference_node_os_state == SoftwareState.PATCHING:
elif final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif initial_node_os_state == SoftwareState.COMPROMISED:
if reference_node_os_state == SoftwareState.GOOD:
elif reference_node_os_state == SoftwareState.COMPROMISED:
if final_node_os_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif reference_node_os_state == SoftwareState.PATCHING:
elif final_node_os_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif reference_node_os_state == SoftwareState.COMPROMISED:
elif final_node_os_state == SoftwareState.COMPROMISED:
score += config_values.compromised
else:
pass
@@ -195,58 +193,57 @@ def score_node_service_state(final_node, initial_node, reference_node, config_va
"""
score = 0
final_node_services: Dict[str, Service] = final_node.services
initial_node_services: Dict[str, Service] = initial_node.services
reference_node_services: Dict[str, Service] = reference_node.services
for service_key, final_service in final_node_services.items():
reference_service = reference_node_services[service_key]
initial_service = initial_node_services[service_key]
final_service = final_node_services[service_key]
if final_service.software_state == reference_service.software_state:
# All is well - we're no different from the reference situation
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference state of node (i.e. at every step)
if initial_service.software_state == SoftwareState.GOOD:
if reference_service.software_state == SoftwareState.PATCHING:
# Need to compare reference and final state of node (i.e. at every step)
if reference_service.software_state == SoftwareState.GOOD:
if final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_good
elif reference_service.software_state == SoftwareState.COMPROMISED:
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_good
elif reference_service.software_state == SoftwareState.OVERWHELMED:
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_good
else:
pass
elif initial_service.software_state == SoftwareState.PATCHING:
if reference_service.software_state == SoftwareState.GOOD:
elif reference_service.software_state == SoftwareState.PATCHING:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_patching
elif reference_service.software_state == SoftwareState.COMPROMISED:
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_patching
elif reference_service.software_state == SoftwareState.OVERWHELMED:
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_patching
elif reference_service.software_state == SoftwareState.PATCHING:
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching
else:
pass
elif initial_service.software_state == SoftwareState.COMPROMISED:
if reference_service.software_state == SoftwareState.GOOD:
elif reference_service.software_state == SoftwareState.COMPROMISED:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_compromised
elif reference_service.software_state == SoftwareState.PATCHING:
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_compromised
elif reference_service.software_state == SoftwareState.COMPROMISED:
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised
elif reference_service.software_state == SoftwareState.OVERWHELMED:
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed_should_be_compromised
else:
pass
elif initial_service.software_state == SoftwareState.OVERWHELMED:
if reference_service.software_state == SoftwareState.GOOD:
elif reference_service.software_state == SoftwareState.OVERWHELMED:
if final_service.software_state == SoftwareState.GOOD:
score += config_values.good_should_be_overwhelmed
elif reference_service.software_state == SoftwareState.PATCHING:
elif final_service.software_state == SoftwareState.PATCHING:
score += config_values.patching_should_be_overwhelmed
elif reference_service.software_state == SoftwareState.COMPROMISED:
elif final_service.software_state == SoftwareState.COMPROMISED:
score += config_values.compromised_should_be_overwhelmed
elif reference_service.software_state == SoftwareState.OVERWHELMED:
elif final_service.software_state == SoftwareState.OVERWHELMED:
score += config_values.overwhelmed
else:
pass
@@ -267,7 +264,6 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
"""
score = 0
final_node_file_system_state = final_node.file_system_state_actual
initial_node_file_system_state = initial_node.file_system_state_actual
reference_node_file_system_state = reference_node.file_system_state_actual
final_node_scanning_state = final_node.file_system_scanning
@@ -279,67 +275,67 @@ def score_node_file_system(final_node, initial_node, reference_node, config_valu
score += config_values.all_ok
else:
# We're different from the reference situation
# Need to compare initial and reference state of node (i.e. at every step)
if initial_node_file_system_state == FileSystemState.GOOD:
if reference_node_file_system_state == FileSystemState.REPAIRING:
# Need to compare reference and final state of node (i.e. at every step)
if reference_node_file_system_state == FileSystemState.GOOD:
if final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_good
elif reference_node_file_system_state == FileSystemState.RESTORING:
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_good
elif reference_node_file_system_state == FileSystemState.CORRUPT:
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_good
elif reference_node_file_system_state == FileSystemState.DESTROYED:
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_good
else:
pass
elif initial_node_file_system_state == FileSystemState.REPAIRING:
if reference_node_file_system_state == FileSystemState.GOOD:
elif reference_node_file_system_state == FileSystemState.REPAIRING:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_repairing
elif reference_node_file_system_state == FileSystemState.RESTORING:
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_repairing
elif reference_node_file_system_state == FileSystemState.CORRUPT:
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_repairing
elif reference_node_file_system_state == FileSystemState.DESTROYED:
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_repairing
elif reference_node_file_system_state == FileSystemState.REPAIRING:
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing
else:
pass
elif initial_node_file_system_state == FileSystemState.RESTORING:
if reference_node_file_system_state == FileSystemState.GOOD:
elif reference_node_file_system_state == FileSystemState.RESTORING:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_restoring
elif reference_node_file_system_state == FileSystemState.REPAIRING:
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_restoring
elif reference_node_file_system_state == FileSystemState.CORRUPT:
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_restoring
elif reference_node_file_system_state == FileSystemState.DESTROYED:
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_restoring
elif reference_node_file_system_state == FileSystemState.RESTORING:
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring
else:
pass
elif initial_node_file_system_state == FileSystemState.CORRUPT:
if reference_node_file_system_state == FileSystemState.GOOD:
elif reference_node_file_system_state == FileSystemState.CORRUPT:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.REPAIRING:
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.RESTORING:
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.DESTROYED:
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed_should_be_corrupt
elif reference_node_file_system_state == FileSystemState.CORRUPT:
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt
else:
pass
elif initial_node_file_system_state == FileSystemState.DESTROYED:
if reference_node_file_system_state == FileSystemState.GOOD:
elif reference_node_file_system_state == FileSystemState.DESTROYED:
if final_node_file_system_state == FileSystemState.GOOD:
score += config_values.good_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.REPAIRING:
elif final_node_file_system_state == FileSystemState.REPAIRING:
score += config_values.repairing_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.RESTORING:
elif final_node_file_system_state == FileSystemState.RESTORING:
score += config_values.restoring_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.CORRUPT:
elif final_node_file_system_state == FileSystemState.CORRUPT:
score += config_values.corrupt_should_be_destroyed
elif reference_node_file_system_state == FileSystemState.DESTROYED:
elif final_node_file_system_state == FileSystemState.DESTROYED:
score += config_values.destroyed
else:
pass