Merge remote-tracking branch 'origin/dev' into feature/915_PRI-31_Packaging_Deployment
# Conflicts:
#   docs/source/about.rst
#   docs/source/config.rst
#   src/primaite/common/config_values_main.py
#   src/primaite/environment/primaite_env.py
#   src/primaite/main.py
#   tests/config/multidiscrete_obs_space_laydown_config.yaml
#   tests/config/obs_tests/laydown.yaml
#   tests/conftest.py
#   tests/test_observation_space.py
@@ -182,16 +182,13 @@ All ACL rules are considered when applying an IER. Logic follows the order of ru
|
|||||||
|
|
||||||
Observation Spaces
|
Observation Spaces
|
||||||
******************
|
******************
|
||||||
|
The observation space provides the blue agent with information about the current status of nodes and links.
|
||||||
|
|
||||||
The OpenAI Gym observation space provides the status of all nodes and links across the whole system:
|
PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationsHandler`. Each training scenario can define its own observation space, and the user can choose which information to include, and how it should be formatted.
|
||||||
|
|
||||||
* Nodes (in terms of hardware state, Software State, file system state and services state)
|
NodeLinkTable component
|
||||||
* Links (in terms of current loading for each service/protocol)
|
-----------------------
|
||||||
|
For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below:
|
||||||
The observation space can be configured as a ``gym.spaces.Box`` or ``gym.spaces.MultiDiscrete``, by setting the ``OBSERVATIONS`` parameter in the laydown config.
|
|
||||||
|
|
||||||
Box-type observation space
|
|
||||||
--------------------------
|
|
||||||
|
|
||||||
An example observation space is provided below:
|
An example observation space is provided below:
|
||||||
|
|
||||||
@@ -249,8 +246,6 @@ An example observation space is provided below:
|
|||||||
- 5000
|
- 5000
|
||||||
- 0
|
- 0
|
||||||
|
|
||||||
The observation space is a 6 x 6 Box type (OpenAI Gym Space) in this example. This is made up from the node and link information detailed below.
|
|
||||||
|
|
||||||
For the nodes, the following values are represented:
|
For the nodes, the following values are represented:
|
||||||
|
|
||||||
* ID
|
* ID
|
||||||
@@ -290,9 +285,9 @@ For the links, the following statuses are represented:
|
|||||||
* SoftwareState = N/A
|
* SoftwareState = N/A
|
||||||
* Protocol = loading in bits/s
|
* Protocol = loading in bits/s
|
||||||
|
|
||||||
MultiDiscrete-type observation space
|
NodeStatuses component
|
||||||
------------------------------------
|
----------------------
|
||||||
The MultiDiscrete observation space can be thought of as a one-dimensional vector of discrete states, represented by integers.
|
This is a MultiDiscrete observation space that can be thought of as a one-dimensional vector of discrete states, represented by integers.
|
||||||
The example above would have the following structure:
|
The example above would have the following structure:
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
@@ -301,9 +296,6 @@ The example above would have the following structure:
|
|||||||
node1_info
|
node1_info
|
||||||
node2_info
|
node2_info
|
||||||
node3_info
|
node3_info
|
||||||
link1_status
|
|
||||||
link2_status
|
|
||||||
link3_status
|
|
||||||
]
|
]
|
||||||
|
|
||||||
Each ``node_info`` contains the following:
|
Each ``node_info`` contains the following:
|
||||||
@@ -318,7 +310,25 @@ Each ``node_info`` contains the following:
|
|||||||
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
|
||||||
]
|
]
|
||||||
|
|
||||||
Each ``link_status`` is just a number from 0-4 representing the network load in relation to bandwidth.
|
In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
|
||||||
|
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4])
|
||||||
|
|
||||||
|
LinkTrafficLevels
|
||||||
|
-----------------
|
||||||
|
This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value.
|
||||||
|
The number of bins can be customised, with 5 being the default. It has the following structure:
|
||||||
|
.. code-block::
|
||||||
|
|
||||||
|
[
|
||||||
|
link1_status
|
||||||
|
link2_status
|
||||||
|
link3_status
|
||||||
|
]
|
||||||
|
|
||||||
|
Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth.
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
@@ -328,11 +338,11 @@ Each ``link_status`` is just a number from 0-4 representing the network load in
|
|||||||
3 = high traffic (<100%)
|
3 = high traffic (<100%)
|
||||||
4 = max traffic/ overwhelmed (100%)
|
4 = max traffic/ overwhelmed (100%)
|
||||||
|
|
||||||
The full observation space would have 15 node-related elements and 3 link-related elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
|
If the network has three links, the full observation space would have 3 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
|
||||||
|
|
||||||
.. code-block::
|
.. code-block::
|
||||||
|
|
||||||
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5])
|
gym.spaces.MultiDiscrete([5,5,5])
|
||||||
|
|
||||||
Action Spaces
|
Action Spaces
|
||||||
**************
|
**************
|
||||||
|
|||||||
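As a rough illustration of the element counts quoted in the examples above, the sketch below builds the per-element option counts with plain Python lists. The helper names `node_statuses_nvec` and `link_traffic_nvec` are hypothetical, and the counts (4, 5 and 6 options for the fixed node fields, 4 options per service, 5 traffic bands) are assumptions taken from the documented examples rather than read from the PrimAITE enums.

```python
from typing import List, Sequence


def node_statuses_nvec(
    num_nodes: int,
    num_services: int,
    fixed_counts: Sequence[int] = (4, 5, 6),  # assumed: hardware, OS, file system
    service_count: int = 4,  # assumed: options per service state
) -> List[int]:
    """Return the number of discrete options for each node-related element."""
    per_node = list(fixed_counts) + [service_count] * num_services
    return per_node * num_nodes


def link_traffic_nvec(num_links: int, quantisation_levels: int = 5) -> List[int]:
    """Return the number of discrete options for each link-related element."""
    return [quantisation_levels] * num_links


# Three nodes with two services each, and three links.
node_nvec = node_statuses_nvec(num_nodes=3, num_services=2)
link_nvec = link_traffic_nvec(num_links=3)
print(len(node_nvec), node_nvec)
print(len(link_nvec), link_nvec)
```

Either list could then be handed to ``gym.spaces.MultiDiscrete`` to obtain a space of the kind shown in the examples.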
@@ -296,6 +296,34 @@ The Lay Down Config
|
|||||||
|
|
||||||
The lay down config file consists of the following attributes:
|
The lay down config file consists of the following attributes:
|
||||||
|
|
||||||
|
* **itemType: ACTIONS** [enum]
|
||||||
|
|
||||||
|
Determines whether a NODE or ACL action space format is adopted for the session
|
||||||
|
|
||||||
|
* **itemType: OBSERVATION_SPACE** [dict]
|
||||||
|
|
||||||
|
Allows the user to configure the observation space by combining one or more observation components. The list of available
|
||||||
|
components is in :py:mod:`primaite.environment.observations`.
|
||||||
|
|
||||||
|
The observation space config item should have a ``components`` key which is a list of components. Each component
|
||||||
|
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
|
||||||
|
component while it is being initialised.
|
||||||
|
|
||||||
|
This example illustrates the correct format for the observation space config item:
|
||||||
|
|
||||||
|
.. code-block:: yaml
|
||||||
|
|
||||||
|
- itemType: OBSERVATION_SPACE
|
||||||
|
components:
|
||||||
|
- name: LINK_TRAFFIC_LEVELS
|
||||||
|
options:
|
||||||
|
combine_service_traffic: false
|
||||||
|
quantisation_levels: 8
|
||||||
|
- name: NODE_STATUSES
|
||||||
|
- name: LINK_TRAFFIC_LEVELS
|
||||||
|
|
||||||
|
* **itemType: STEPS** [int]
|
||||||
|
|
||||||
* **item_type: PORTS** [int]
|
* **item_type: PORTS** [int]
|
||||||
|
|
||||||
Provides a list of ports modelled in this session
|
Provides a list of ports modelled in this session
|
||||||
|
|||||||
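To make the ``components`` / ``name`` / ``options`` layout concrete, here is a minimal sketch of walking such a config item once it has been loaded into a Python dictionary. The `parse_observation_config` helper and the `KNOWN_COMPONENTS` set are hypothetical stand-ins for illustration; the component names are the ones listed in this commit's observation registry.

```python
from typing import Any, Dict, List, Tuple

# Component names as they appear in the registry added by this commit.
KNOWN_COMPONENTS = {"NODE_LINK_TABLE", "NODE_STATUSES", "LINK_TRAFFIC_LEVELS"}


def parse_observation_config(item: Dict[str, Any]) -> List[Tuple[str, Dict[str, Any]]]:
    """Return (component name, options) pairs from an OBSERVATION_SPACE item."""
    parsed = []
    for component_cfg in item["components"]:
        name = component_cfg["name"]  # ``name`` is mandatory
        if name not in KNOWN_COMPONENTS:
            raise ValueError(f"Unknown observation component: {name}")
        # ``options`` is optional; a missing or empty value means no options.
        options = component_cfg.get("options") or {}
        parsed.append((name, dict(options)))
    return parsed


config_item = {
    "itemType": "OBSERVATION_SPACE",
    "components": [
        {
            "name": "LINK_TRAFFIC_LEVELS",
            "options": {"combine_service_traffic": False, "quantisation_levels": 8},
        },
        {"name": "NODE_STATUSES"},
    ],
}

parsed = parse_observation_config(config_item)
for name, options in parsed:
    print(name, options)
```

Each options dictionary plays the role of the keyword arguments that are passed to the component while it is being initialised.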
@@ -1,3 +1,5 @@
|
|||||||
[pytest]
|
[pytest]
|
||||||
testpaths =
|
testpaths =
|
||||||
tests
|
tests
|
||||||
|
markers =
|
||||||
|
env_config_paths
|
||||||
|
|||||||
91
src/primaite/common/config_values_main.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||||
|
"""The config class."""
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigValuesMain(object):
|
||||||
|
"""Class to hold main config values."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""Init."""
|
||||||
|
# Generic
|
||||||
|
self.agent_identifier = "" # the agent in use
|
||||||
|
self.observation_config = None # observation space config
|
||||||
|
self.num_episodes = 0 # number of episodes to train over
|
||||||
|
self.num_steps = 0 # number of steps in an episode
|
||||||
|
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only
|
||||||
|
self.config_filename_use_case = "" # the filename for the Use Case config file
|
||||||
|
self.session_type = "" # the session type to run (TRAINING or EVALUATION)
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
self.observation_space_high_value = (
|
||||||
|
0 # The high value for the observation space
|
||||||
|
)
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
self.all_ok = 0
|
||||||
|
# Node Hardware State
|
||||||
|
self.off_should_be_on = 0
|
||||||
|
self.off_should_be_resetting = 0
|
||||||
|
self.on_should_be_off = 0
|
||||||
|
self.on_should_be_resetting = 0
|
||||||
|
self.resetting_should_be_on = 0
|
||||||
|
self.resetting_should_be_off = 0
|
||||||
|
self.resetting = 0
|
||||||
|
# Node Software or Service State
|
||||||
|
self.good_should_be_patching = 0
|
||||||
|
self.good_should_be_compromised = 0
|
||||||
|
self.good_should_be_overwhelmed = 0
|
||||||
|
self.patching_should_be_good = 0
|
||||||
|
self.patching_should_be_compromised = 0
|
||||||
|
self.patching_should_be_overwhelmed = 0
|
||||||
|
self.patching = 0
|
||||||
|
self.compromised_should_be_good = 0
|
||||||
|
self.compromised_should_be_patching = 0
|
||||||
|
self.compromised_should_be_overwhelmed = 0
|
||||||
|
self.compromised = 0
|
||||||
|
self.overwhelmed_should_be_good = 0
|
||||||
|
self.overwhelmed_should_be_patching = 0
|
||||||
|
self.overwhelmed_should_be_compromised = 0
|
||||||
|
self.overwhelmed = 0
|
||||||
|
# Node File System State
|
||||||
|
self.good_should_be_repairing = 0
|
||||||
|
self.good_should_be_restoring = 0
|
||||||
|
self.good_should_be_corrupt = 0
|
||||||
|
self.good_should_be_destroyed = 0
|
||||||
|
self.repairing_should_be_good = 0
|
||||||
|
self.repairing_should_be_restoring = 0
|
||||||
|
self.repairing_should_be_corrupt = 0
|
||||||
|
self.repairing_should_be_destroyed = (
|
||||||
|
0 # Repairing does not fix destroyed state - you need to restore
|
||||||
|
)
|
||||||
|
self.repairing = 0
|
||||||
|
self.restoring_should_be_good = 0
|
||||||
|
self.restoring_should_be_repairing = 0
|
||||||
|
self.restoring_should_be_corrupt = (
|
||||||
|
0 # Not the optimal method (as repair will fix corruption)
|
||||||
|
)
|
||||||
|
self.restoring_should_be_destroyed = 0
|
||||||
|
self.restoring = 0
|
||||||
|
self.corrupt_should_be_good = 0
|
||||||
|
self.corrupt_should_be_repairing = 0
|
||||||
|
self.corrupt_should_be_restoring = 0
|
||||||
|
self.corrupt_should_be_destroyed = 0
|
||||||
|
self.corrupt = 0
|
||||||
|
self.destroyed_should_be_good = 0
|
||||||
|
self.destroyed_should_be_repairing = 0
|
||||||
|
self.destroyed_should_be_restoring = 0
|
||||||
|
self.destroyed_should_be_corrupt = 0
|
||||||
|
self.destroyed = 0
|
||||||
|
self.scanning = 0
|
||||||
|
# IER status
|
||||||
|
self.red_ier_running = 0
|
||||||
|
self.green_ier_blocked = 0
|
||||||
|
|
||||||
|
# Patching / Reset
|
||||||
|
self.os_patching_duration = 0 # The time taken to patch the OS
|
||||||
|
self.node_reset_duration = 0 # The time taken to reset a node (hardware)
|
||||||
|
self.service_patching_duration = 0 # The time taken to patch a service
|
||||||
|
self.file_system_repairing_limit = 0 # The time taken to repair a file
|
||||||
|
self.file_system_restoring_limit = 0 # The time taken to restore a file
|
||||||
|
self.file_system_scanning_limit = 0 # The time taken to scan the file system
|
||||||
403
src/primaite/environment/observations.py
Normal file
@@ -0,0 +1,403 @@
|
|||||||
|
"""Module for handling configurable observation spaces in PrimAITE."""
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from gym import spaces
|
||||||
|
|
||||||
|
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
|
||||||
|
from primaite.nodes.active_node import ActiveNode
|
||||||
|
from primaite.nodes.service_node import ServiceNode
|
||||||
|
|
||||||
|
# This dependency is only needed for type hints,
|
||||||
|
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
||||||
|
# Therefore, this avoids circular dependency problem.
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from primaite.environment.primaite_env import Primaite
|
||||||
|
|
||||||
|
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractObservationComponent(ABC):
|
||||||
|
"""Represents a part of the PrimAITE observation space."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def __init__(self, env: "Primaite"):
|
||||||
|
_LOGGER.info(f"Initialising {self} observation component")
|
||||||
|
self.env: "Primaite" = env
|
||||||
|
self.space: spaces.Space
|
||||||
|
self.current_observation: np.ndarray # type might be too restrictive?
|
||||||
|
return NotImplemented
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def update(self):
|
||||||
|
"""Update the observation based on the current state of the environment."""
|
||||||
|
self.current_observation = NotImplemented
|
||||||
|
|
||||||
|
|
||||||
|
class NodeLinkTable(AbstractObservationComponent):
|
||||||
|
"""Table with nodes and links as rows and hardware/software status as cols.
|
||||||
|
|
||||||
|
This will create the observation space formatted as a table of integers.
|
||||||
|
There is one row per node, followed by one row per link.
|
||||||
|
The number of columns is 4 plus one per service. They are:
|
||||||
|
* node/link ID
|
||||||
|
* node hardware status / 0 for links
|
||||||
|
* node operating system status (if active/service) / 0 for links
|
||||||
|
* node file system status (active/service only) / 0 for links
|
||||||
|
* node service1 status / traffic load from that service for links
|
||||||
|
* node service2 status / traffic load from that service for links
|
||||||
|
* ...
|
||||||
|
* node serviceN status / traffic load from that service for links
|
||||||
|
|
||||||
|
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
||||||
|
``(12, 7)``
|
||||||
|
"""
|
||||||
|
|
||||||
|
_FIXED_PARAMETERS: int = 4
|
||||||
|
_MAX_VAL: int = 1_000_000
|
||||||
|
_DATA_TYPE: type = np.int64
|
||||||
|
|
||||||
|
def __init__(self, env: "Primaite"):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
# 1. Define the shape of your observation space component
|
||||||
|
num_items = self.env.num_links + self.env.num_nodes
|
||||||
|
num_columns = self.env.num_services + self._FIXED_PARAMETERS
|
||||||
|
observation_shape = (num_items, num_columns)
|
||||||
|
|
||||||
|
# 2. Create Observation space
|
||||||
|
self.space = spaces.Box(
|
||||||
|
low=0,
|
||||||
|
high=self._MAX_VAL,
|
||||||
|
shape=observation_shape,
|
||||||
|
dtype=self._DATA_TYPE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Initialise Observation with zeroes
|
||||||
|
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""Update the observation based on current environment state.
|
||||||
|
|
||||||
|
The structure of the observation space is described in :class:`.NodeLinkTable`
|
||||||
|
"""
|
||||||
|
item_index = 0
|
||||||
|
nodes = self.env.nodes
|
||||||
|
links = self.env.links
|
||||||
|
# Do nodes first
|
||||||
|
for _, node in nodes.items():
|
||||||
|
self.current_observation[item_index][0] = int(node.node_id)
|
||||||
|
self.current_observation[item_index][1] = node.hardware_state.value
|
||||||
|
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||||
|
self.current_observation[item_index][2] = node.software_state.value
|
||||||
|
self.current_observation[item_index][
|
||||||
|
3
|
||||||
|
] = node.file_system_state_observed.value
|
||||||
|
else:
|
||||||
|
self.current_observation[item_index][2] = 0
|
||||||
|
self.current_observation[item_index][3] = 0
|
||||||
|
service_index = 4
|
||||||
|
if isinstance(node, ServiceNode):
|
||||||
|
for service in self.env.services_list:
|
||||||
|
if node.has_service(service):
|
||||||
|
self.current_observation[item_index][
|
||||||
|
service_index
|
||||||
|
] = node.get_service_state(service).value
|
||||||
|
else:
|
||||||
|
self.current_observation[item_index][service_index] = 0
|
||||||
|
service_index += 1
|
||||||
|
else:
|
||||||
|
# Not a service node
|
||||||
|
for service in self.env.services_list:
|
||||||
|
self.current_observation[item_index][service_index] = 0
|
||||||
|
service_index += 1
|
||||||
|
item_index += 1
|
||||||
|
|
||||||
|
# Now do links
|
||||||
|
for _, link in links.items():
|
||||||
|
self.current_observation[item_index][0] = int(link.get_id())
|
||||||
|
self.current_observation[item_index][1] = 0
|
||||||
|
self.current_observation[item_index][2] = 0
|
||||||
|
self.current_observation[item_index][3] = 0
|
||||||
|
protocol_list = link.get_protocol_list()
|
||||||
|
protocol_index = 0
|
||||||
|
for protocol in protocol_list:
|
||||||
|
self.current_observation[item_index][
|
||||||
|
protocol_index + 4
|
||||||
|
] = protocol.get_load()
|
||||||
|
protocol_index += 1
|
||||||
|
item_index += 1
|
||||||
|
|
||||||
|
|
||||||
|
class NodeStatuses(AbstractObservationComponent):
|
||||||
|
"""Flat list of nodes' hardware, OS, file system, and service states.
|
||||||
|
|
||||||
|
The MultiDiscrete observation space can be thought of as a one-dimensional vector of discrete states, represented by
|
||||||
|
integers.
|
||||||
|
Each node has 3 elements plus 1 per service. It will have the following structure:
|
||||||
|
.. code-block::
|
||||||
|
[
|
||||||
|
node1 hardware state,
|
||||||
|
node1 OS state,
|
||||||
|
node1 file system state,
|
||||||
|
node1 service1 state,
|
||||||
|
node1 service2 state,
|
||||||
|
node1 serviceN state (one for each service),
|
||||||
|
node2 hardware state,
|
||||||
|
node2 OS state,
|
||||||
|
node2 file system state,
|
||||||
|
node2 service1 state,
|
||||||
|
node2 service2 state,
|
||||||
|
node2 serviceN state (one for each service),
|
||||||
|
...
|
||||||
|
]
|
||||||
|
|
||||||
|
:param env: The environment that forms the basis of the observations
|
||||||
|
:type env: Primaite
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DATA_TYPE: type = np.int64
|
||||||
|
|
||||||
|
def __init__(self, env: "Primaite"):
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
# 1. Define the shape of your observation space component
|
||||||
|
node_shape = [
|
||||||
|
len(HardwareState) + 1,
|
||||||
|
len(SoftwareState) + 1,
|
||||||
|
len(FileSystemState) + 1,
|
||||||
|
]
|
||||||
|
services_shape = [len(SoftwareState) + 1] * self.env.num_services
|
||||||
|
node_shape = node_shape + services_shape
|
||||||
|
|
||||||
|
shape = node_shape * self.env.num_nodes
|
||||||
|
# 2. Create Observation space
|
||||||
|
self.space = spaces.MultiDiscrete(shape)
|
||||||
|
|
||||||
|
# 3. Initialise observation with zeroes
|
||||||
|
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""Update the observation based on current environment state.
|
||||||
|
|
||||||
|
The structure of the observation space is described in :class:`.NodeStatuses`
|
||||||
|
"""
|
||||||
|
obs = []
|
||||||
|
for _, node in self.env.nodes.items():
|
||||||
|
hardware_state = node.hardware_state.value
|
||||||
|
software_state = 0
|
||||||
|
file_system_state = 0
|
||||||
|
service_states = [0] * self.env.num_services
|
||||||
|
|
||||||
|
if isinstance(node, ActiveNode):
|
||||||
|
software_state = node.software_state.value
|
||||||
|
file_system_state = node.file_system_state_observed.value
|
||||||
|
|
||||||
|
if isinstance(node, ServiceNode):
|
||||||
|
for i, service in enumerate(self.env.services_list):
|
||||||
|
if node.has_service(service):
|
||||||
|
service_states[i] = node.get_service_state(service).value
|
||||||
|
obs.extend(
|
||||||
|
[hardware_state, software_state, file_system_state, *service_states]
|
||||||
|
)
|
||||||
|
self.current_observation[:] = obs
|
||||||
|
|
||||||
|
|
||||||
|
class LinkTrafficLevels(AbstractObservationComponent):
|
||||||
|
"""Flat list of traffic levels encoded into banded categories.
|
||||||
|
|
||||||
|
For each link, total traffic or traffic per service is encoded into a categorical value.
|
||||||
|
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
||||||
|
0 = No traffic (0% of bandwidth)
|
||||||
|
1 = Low traffic (0%-33% of bandwidth)
|
||||||
|
2 = Medium traffic (33%-66% of bandwidth)
|
||||||
|
3 = High traffic (66%-100% of bandwidth)
|
||||||
|
4 = Max traffic (100% of bandwidth)
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
|
||||||
|
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
|
||||||
|
|
||||||
|
:param env: The environment that forms the basis of the observations
|
||||||
|
:type env: Primaite
|
||||||
|
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
|
||||||
|
defaults to False
|
||||||
|
:type combine_service_traffic: bool, optional
|
||||||
|
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value,
|
||||||
|
defaults to 5
|
||||||
|
:type quantisation_levels: int, optional
|
||||||
|
"""
|
||||||
|
|
||||||
|
_DATA_TYPE: type = np.int64
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
env: "Primaite",
|
||||||
|
combine_service_traffic: bool = False,
|
||||||
|
quantisation_levels: int = 5,
|
||||||
|
):
|
||||||
|
if quantisation_levels < 3:
|
||||||
|
_msg = (
|
||||||
|
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
|
||||||
|
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
|
||||||
|
f"Resetting to default value (5)"
|
||||||
|
)
|
||||||
|
_LOGGER.warning(_msg)
|
||||||
|
quantisation_levels = 5
|
||||||
|
|
||||||
|
super().__init__(env)
|
||||||
|
|
||||||
|
self._combine_service_traffic: bool = combine_service_traffic
|
||||||
|
self._quantisation_levels: int = quantisation_levels
|
||||||
|
self._entries_per_link: int = 1
|
||||||
|
|
||||||
|
if not self._combine_service_traffic:
|
||||||
|
self._entries_per_link = self.env.num_services
|
||||||
|
|
||||||
|
# 1. Define the shape of your observation space component
|
||||||
|
shape = (
|
||||||
|
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Create Observation space
|
||||||
|
self.space = spaces.MultiDiscrete(shape)
|
||||||
|
|
||||||
|
# 3. Initialise observation with zeroes
|
||||||
|
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||||
|
|
||||||
|
def update(self):
|
||||||
|
"""Update the observation based on current environment state.
|
||||||
|
|
||||||
|
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
||||||
|
"""
|
||||||
|
obs = []
|
||||||
|
for _, link in self.env.links.items():
|
||||||
|
bandwidth = link.bandwidth
|
||||||
|
if self._combine_service_traffic:
|
||||||
|
loads = [link.get_current_load()]
|
||||||
|
else:
|
||||||
|
loads = [protocol.get_load() for protocol in link.protocol_list]
|
||||||
|
|
||||||
|
for load in loads:
|
||||||
|
if load <= 0:
|
||||||
|
traffic_level = 0
|
||||||
|
elif load >= bandwidth:
|
||||||
|
traffic_level = self._quantisation_levels - 1
|
||||||
|
else:
|
||||||
|
traffic_level = (load / bandwidth) // (
|
||||||
|
1 / (self._quantisation_levels - 2)
|
||||||
|
) + 1
|
||||||
|
|
||||||
|
obs.append(int(traffic_level))
|
||||||
|
|
||||||
|
self.current_observation[:] = obs
|
||||||
|
|
||||||
|
|
||||||
|
class ObservationsHandler:
|
||||||
|
"""Component-based observation space handler.
|
||||||
|
|
||||||
|
This allows users to configure observation spaces by mixing and matching components.
|
||||||
|
Each component can also define further parameters to make them more flexible.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_REGISTRY: Final[Dict[str, type]] = {
|
||||||
|
"NODE_LINK_TABLE": NodeLinkTable,
|
||||||
|
"NODE_STATUSES": NodeStatuses,
|
||||||
|
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.registered_obs_components: List[AbstractObservationComponent] = []
|
||||||
|
self.space: spaces.Space
|
||||||
|
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
||||||
|
|
||||||
|
def update_obs(self):
|
||||||
|
"""Fetch fresh information about the environment."""
|
||||||
|
current_obs = []
|
||||||
|
for obs in self.registered_obs_components:
|
||||||
|
obs.update()
|
||||||
|
current_obs.append(obs.current_observation)
|
||||||
|
|
||||||
|
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
||||||
|
if len(current_obs) == 1:
|
||||||
|
self.current_observation = current_obs[0]
|
||||||
|
else:
|
||||||
|
self.current_observation = tuple(current_obs)
|
||||||
|
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||||
|
|
||||||
|
def register(self, obs_component: AbstractObservationComponent):
|
||||||
|
"""Add a component for this handler to track.
|
||||||
|
|
||||||
|
:param obs_component: The component to add.
|
||||||
|
:type obs_component: AbstractObservationComponent
|
||||||
|
"""
|
||||||
|
self.registered_obs_components.append(obs_component)
|
||||||
|
self.update_space()
|
||||||
|
|
||||||
|
def deregister(self, obs_component: AbstractObservationComponent):
|
||||||
|
"""Remove a component from this handler.
|
||||||
|
|
||||||
|
:param obs_component: Which component to remove. It must exist within this object's
|
||||||
|
``registered_obs_components`` attribute.
|
||||||
|
:type obs_component: AbstractObservationComponent
|
||||||
|
"""
|
||||||
|
self.registered_obs_components.remove(obs_component)
|
||||||
|
self.update_space()
|
||||||
|
|
||||||
|
def update_space(self):
|
||||||
|
"""Rebuild the handler's composite observation space from its components."""
|
||||||
|
component_spaces = []
|
||||||
|
for obs_comp in self.registered_obs_components:
|
||||||
|
component_spaces.append(obs_comp.space)
|
||||||
|
|
||||||
|
# If there is only one component, don't use a tuple space, just pass through that component's space.
|
||||||
|
if len(component_spaces) == 1:
|
||||||
|
self.space = component_spaces[0]
|
||||||
|
else:
|
||||||
|
self.space = spaces.Tuple(component_spaces)
|
||||||
|
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
||||||
|
"""Parse a config dictionary and return a new observation handler populated with new observation component objects.
|
||||||
|
|
||||||
|
The expected format for the config dictionary is:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
config = {
|
||||||
|
"components": [
|
||||||
|
{
|
||||||
|
"name": "<COMPONENT1_NAME>"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "<COMPONENT2_NAME>",
|
||||||
|
"options": {"opt1": val1, "opt2": val2}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
...
|
||||||
|
},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
:return: Observation handler
|
||||||
|
:rtype: primaite.environment.observations.ObservationsHandler
|
||||||
|
"""
|
||||||
|
# Instantiate the handler
|
||||||
|
handler = cls()
|
||||||
|
|
||||||
|
for component_cfg in obs_space_config["components"]:
|
||||||
|
# Figure out which class can instantiate the desired component
|
||||||
|
comp_type = component_cfg["name"]
|
||||||
|
comp_class = cls._REGISTRY[comp_type]
|
||||||
|
|
||||||
|
# Create the component with options from the YAML
|
||||||
|
options = component_cfg.get("options") or {}
|
||||||
|
component = comp_class(env, **options)
|
||||||
|
|
||||||
|
handler.register(component)
|
||||||
|
|
||||||
|
handler.update_obs()
|
||||||
|
return handler
|
||||||
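The banding rule used by `LinkTrafficLevels.update` in the file above can be read in isolation as the sketch below. The stand-alone function `quantise_load` is a hypothetical extraction for illustration, not part of PrimAITE. It assumes the same conventions as the diff: level 0 means no traffic, the top level means the link is at or above capacity, and fewer than 3 levels falls back to the default of 5.

```python
def quantise_load(load: float, bandwidth: float, quantisation_levels: int = 5) -> int:
    """Convert a traffic load into a banded category.

    Level 0 is reserved for no traffic and the highest level for a link at or
    above its bandwidth; everything in between is split evenly into the
    remaining ``quantisation_levels - 2`` bands.
    """
    if quantisation_levels < 3:
        # Same fallback as in the diff: too few levels resets to the default.
        quantisation_levels = 5
    if load <= 0:
        return 0
    if load >= bandwidth:
        return quantisation_levels - 1
    band_size = 1 / (quantisation_levels - 2)
    return int((load / bandwidth) // band_size) + 1


# Loads on a link with a bandwidth of 100 units.
levels = [quantise_load(load, 100.0) for load in (0, 10, 50, 90, 100)]
print(levels)
```

With the default of 5 levels, the three middle bands each cover a third of the bandwidth, which matches the percentages listed in the class docstring.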
@@ -24,11 +24,11 @@ from primaite.common.enums import (
|
|||||||
NodePOLInitiator,
|
NodePOLInitiator,
|
||||||
NodePOLType,
|
NodePOLType,
|
||||||
NodeType,
|
NodeType,
|
||||||
ObservationType,
|
|
||||||
Priority,
|
Priority,
|
||||||
SoftwareState,
|
SoftwareState,
|
||||||
)
|
)
|
||||||
from primaite.common.service import Service
|
from primaite.common.service import Service
|
||||||
|
from primaite.environment.observations import ObservationsHandler
|
||||||
from primaite.config import training_config
|
from primaite.config import training_config
|
||||||
from primaite.config.training_config import TrainingConfig
|
from primaite.config.training_config import TrainingConfig
|
||||||
from primaite.environment.reward import calculate_reward_function
|
from primaite.environment.reward import calculate_reward_function
|
||||||
@@ -51,14 +51,11 @@ _LOGGER.setLevel(logging.INFO)
|
|||||||
class Primaite(Env):
|
class Primaite(Env):
|
||||||
"""PRIMary AI Training Environment (Primaite) class."""
|
"""PRIMary AI Training Environment (Primaite) class."""
|
||||||
|
|
||||||
# Observation / Action Space constants
|
# Action Space constants
|
||||||
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
|
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
|
||||||
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
|
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
|
||||||
ACTION_SPACE_NODE_ACTION_VALUES = 4
|
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
|
||||||
ACTION_SPACE_ACL_ACTION_VALUES = 3
|
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
|
||||||
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
|
|
||||||
|
|
||||||
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -165,8 +162,18 @@ class Primaite(Env):
|
|||||||
# Number of ports - gets a value when config is loaded
|
# Number of ports - gets a value when config is loaded
|
||||||
self.num_ports = 0
|
self.num_ports = 0
|
||||||
|
|
||||||
# Observation type, by default box.
|
# The action type
|
||||||
self.observation_type = ObservationType.BOX
|
self.action_type = 0
|
||||||
|
|
||||||
|
# TODO fix up with TrainingConfig
|
||||||
|
# stores the observation config from the yaml, default is NODE_LINK_TABLE
|
||||||
|
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||||
|
if self.config_values.observation_config is not None:
|
||||||
|
self.obs_config = self.config_values.observation_config
|
||||||
|
|
||||||
|
# Observation Handler manages the user-configurable observation space.
|
||||||
|
# It will be initialised later.
|
||||||
|
self.obs_handler: ObservationsHandler
|
||||||
|
|
||||||
|
|
||||||
# Open the config file and build the environment laydown
|
# Open the config file and build the environment laydown
|
||||||
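The fallback introduced in this hunk, together with the table shape given in the `NodeLinkTable` docstring, can be sketched as follows. `resolve_obs_config` and `node_link_table_shape` are hypothetical stand-alone helpers written for illustration; they are not functions in PrimAITE.

```python
from typing import Optional, Tuple

# Default used when the config supplies no observation space item.
DEFAULT_OBS_CONFIG = {"components": [{"name": "NODE_LINK_TABLE"}]}

# Fixed columns: ID, hardware state, OS state, file system state.
FIXED_PARAMETERS = 4


def resolve_obs_config(user_config: Optional[dict]) -> dict:
    """Return the user's observation config, or the default if none was given."""
    return DEFAULT_OBS_CONFIG if user_config is None else user_config


def node_link_table_shape(num_nodes: int, num_links: int, num_services: int) -> Tuple[int, int]:
    """One row per node and per link; four fixed columns plus one per service."""
    return (num_nodes + num_links, num_services + FIXED_PARAMETERS)


obs_config = resolve_obs_config(None)
shape = node_link_table_shape(num_nodes=5, num_links=7, num_services=3)
print(obs_config, shape)
```

The shape computed here is the `(12, 7)` example quoted in the `NodeLinkTable` docstring.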
@@ -229,7 +236,7 @@ class Primaite(Env):
|
|||||||
self.action_dict = self.create_node_and_acl_action_dict()
|
self.action_dict = self.create_node_and_acl_action_dict()
|
||||||
self.action_space = spaces.Discrete(len(self.action_dict))
|
self.action_space = spaces.Discrete(len(self.action_dict))
|
||||||
else:
|
else:
|
||||||
_LOGGER.info(f"Invalid action type selected")
|
_LOGGER.info(f"Invalid action type selected: {self.training_config.action_type}")
|
||||||
# Set up a csv to store the results of the training
|
# Set up a csv to store the results of the training
|
||||||
try:
|
try:
|
||||||
header = ["Episode", "Average Reward"]
|
header = ["Episode", "Average Reward"]
|
||||||
@@ -424,9 +431,7 @@ class Primaite(Env):
|
|||||||
_action: The action space from the agent
|
_action: The action space from the agent
|
||||||
"""
|
"""
|
||||||
# At the moment, actions are only affecting nodes
|
# At the moment, actions are only affecting nodes
|
||||||
print("")
|
|
||||||
print(_action)
|
|
||||||
print(self.action_dict)
|
|
||||||
if self.training_config.action_type == ActionType.NODE:
|
if self.training_config.action_type == ActionType.NODE:
|
||||||
self.apply_actions_to_nodes(_action)
|
self.apply_actions_to_nodes(_action)
|
||||||
elif self.training_config.action_type == ActionType.ACL:
|
elif self.training_config.action_type == ActionType.ACL:
|
||||||
@@ -652,252 +657,20 @@ class Primaite(Env):
|
|||||||
else:
|
else:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
|
||||||
"""Initialise the observation space with the BOX option chosen.
|
|
||||||
|
|
||||||
This will create the observation space formatted as a table of integers.
|
|
||||||
There is one row per node, followed by one row per link.
|
|
||||||
Columns are as follows:
|
|
||||||
* node/link ID
|
|
||||||
* node hardware status / 0 for links
|
|
||||||
* node operating system status (if active/service) / 0 for links
|
|
||||||
* node file system status (active/service only) / 0 for links
|
|
||||||
* node service1 status / traffic load from that service for links
|
|
||||||
* node service2 status / traffic load from that service for links
|
|
||||||
* ...
|
|
||||||
* node serviceN status / traffic load from that service for links
|
|
||||||
|
|
||||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
|
||||||
``(12, 7)``
|
|
||||||
|
|
||||||
:return: Box gym observation
|
|
||||||
:rtype: gym.spaces.Box
|
|
||||||
:return: Initial observation with all entries set to 0
|
|
||||||
:rtype: numpy.ndarray
|
|
||||||
"""
|
|
||||||
_LOGGER.info("Observation space type BOX selected")
|
|
||||||
|
|
||||||
# 1. Determine observation shape from laydown
|
|
||||||
num_items = self.num_links + self.num_nodes
|
|
||||||
num_observation_parameters = (
|
|
||||||
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
|
|
||||||
)
|
|
||||||
observation_shape = (num_items, num_observation_parameters)
|
|
||||||
|
|
||||||
# 2. Create observation space & zeroed out sample from space.
|
|
||||||
observation_space = spaces.Box(
|
|
||||||
low=0,
|
|
||||||
high=self.OBSERVATION_SPACE_HIGH_VALUE,
|
|
||||||
shape=observation_shape,
|
|
||||||
dtype=np.int64,
|
|
||||||
)
|
|
||||||
initial_observation = np.zeros(observation_shape, dtype=np.int64)
|
|
||||||
|
|
||||||
return observation_space, initial_observation
|
|
||||||
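Review note: the shape arithmetic described in the removed `_init_box_observations` docstring can be checked with a small standalone sketch. The helper name `box_observation_shape` is hypothetical and not part of PrimAITE; the fixed column count of 4 is taken from `OBSERVATION_SPACE_FIXED_PARAMETERS` in this diff. Plain lists stand in for the NumPy array so the sketch has no dependencies.

```python
from typing import List, Tuple

# Fixed columns per row: ID, hardware state, OS state, file system state.
OBSERVATION_SPACE_FIXED_PARAMETERS = 4


def box_observation_shape(
    num_nodes: int, num_links: int, num_services: int
) -> Tuple[int, int]:
    """Return (rows, columns) of the node/link table (hypothetical helper)."""
    rows = num_nodes + num_links  # one row per node, then one row per link
    columns = num_services + OBSERVATION_SPACE_FIXED_PARAMETERS
    return (rows, columns)


def zeroed_table(shape: Tuple[int, int]) -> List[List[int]]:
    """Build an all-zero table, standing in for np.zeros(shape, dtype=np.int64)."""
    rows, columns = shape
    return [[0] * columns for _ in range(rows)]


shape = box_observation_shape(num_nodes=5, num_links=7, num_services=3)
initial_observation = zeroed_table(shape)
print(shape)  # (12, 7)
```

This reproduces the ``(12, 7)`` example from the docstring: 5 + 7 rows and 3 + 4 columns.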
|
|
||||||
def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
|
||||||
"""Initialise the observation space with the MULTIDISCRETE option chosen.
|
|
||||||
|
|
||||||
This will create the observation space with node observations followed by link observations.
|
|
||||||
Each node has 3 elements in the observation space plus 1 per service, more specifically:
|
|
||||||
* hardware state
|
|
||||||
* operating system state
|
|
||||||
* file system state
|
|
||||||
* service states (one per service)
|
|
||||||
Each link has one element in the observation space, corresponding to the traffic load,
|
|
||||||
it can take the following values:
|
|
||||||
0 = No traffic (0% of bandwidth)
|
|
||||||
1 = Low traffic (0%-33% of bandwidth)
|
|
||||||
2 = Medium traffic (33%-66% of bandwidth)
|
|
||||||
3 = High traffic (66%-100% of bandwidth)
|
|
||||||
4 = Max traffic/overloaded (100% of bandwidth)
|
|
||||||
|
|
||||||
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
|
||||||
``(37,)``
|
|
||||||
|
|
||||||
:return: MultiDiscrete gym observation
|
|
||||||
:rtype: gym.spaces.MultiDiscrete
|
|
||||||
:return: Initial observation with all entries set to 0
|
|
||||||
:rtype: numpy.ndarray
|
|
||||||
"""
|
|
||||||
_LOGGER.info("Observation space MULTIDISCRETE selected")
|
|
||||||
|
|
||||||
# 1. Determine observation shape from laydown
|
|
||||||
node_obs_shape = [
|
|
||||||
len(HardwareState) + 1,
|
|
||||||
len(SoftwareState) + 1,
|
|
||||||
len(FileSystemState) + 1,
|
|
||||||
]
|
|
||||||
node_services = [len(SoftwareState) + 1] * self.num_services
|
|
||||||
node_obs_shape = node_obs_shape + node_services
|
|
||||||
# the magic number 5 refers to 5 states of quantisation of traffic amount.
|
|
||||||
# (zero, low, medium, high, fully utilised/overwhelmed)
|
|
||||||
link_obs_shape = [5] * self.num_links
|
|
||||||
observation_shape = node_obs_shape * self.num_nodes + link_obs_shape
|
|
||||||
|
|
||||||
# 2. Create observation space & zeroed out sample from space.
|
|
||||||
observation_space = spaces.MultiDiscrete(observation_shape)
|
|
||||||
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
|
|
||||||
|
|
||||||
return observation_space, initial_observation
|
|
||||||
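Review note: the length of the MultiDiscrete space built by the removed `_init_multidiscrete_observations` can be sketched without Gym. The helper `multidiscrete_nvec` is hypothetical, and the enum sizes used as defaults (3 hardware states, 4 software states, 5 file system states) are assumptions for illustration only; the real values come from `len(HardwareState)`, `len(SoftwareState)` and `len(FileSystemState)`. The total length does not depend on those sizes.

```python
from typing import List


def multidiscrete_nvec(
    num_nodes: int,
    num_links: int,
    num_services: int,
    hardware_states: int = 3,  # assumed enum size, illustration only
    software_states: int = 4,  # assumed enum size, illustration only
    file_system_states: int = 5,  # assumed enum size, illustration only
) -> List[int]:
    """Return the per-element category counts (hypothetical helper)."""
    # Each state gets one extra category so that 0 can mean "not applicable".
    node = [hardware_states + 1, software_states + 1, file_system_states + 1]
    node += [software_states + 1] * num_services
    # 5 traffic levels per link: zero, low, medium, high, fully utilised.
    links = [5] * num_links
    return node * num_nodes + links


nvec = multidiscrete_nvec(num_nodes=5, num_links=7, num_services=3)
print(len(nvec))  # 37
```

Each node contributes 3 + 3 elements and each link contributes 1, giving 5 * 6 + 7 = 37, which matches the ``(37,)`` shape quoted in the docstring.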
|
|
||||||
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
|
||||||
"""Build the observation space based on network laydown and provide initial obs.
|
"""Create the environment's observation handler.
|
||||||
|
|
||||||
This method uses the object's `num_links`, `num_nodes`, `num_services`,
|
:return: The observation space, initial observation (zeroed out array with the correct shape)
|
||||||
`OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type`
|
:rtype: Tuple[spaces.Space, np.ndarray]
|
||||||
attributes to figure out the correct shape and format for the observation space.
|
|
||||||
|
|
||||||
:raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType`
|
|
||||||
:return: Gym observation space
|
|
||||||
:rtype: gym.spaces.Space
|
|
||||||
:return: Initial observation with all entries set to 0
|
|
||||||
:rtype: numpy.ndarray
|
|
||||||
"""
|
"""
|
||||||
if self.observation_type == ObservationType.BOX:
|
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
|
||||||
observation_space, initial_observation = self._init_box_observations()
|
|
||||||
return observation_space, initial_observation
|
|
||||||
elif self.observation_type == ObservationType.MULTIDISCRETE:
|
|
||||||
(
|
|
||||||
observation_space,
|
|
||||||
initial_observation,
|
|
||||||
) = self._init_multidiscrete_observations()
|
|
||||||
return observation_space, initial_observation
|
|
||||||
else:
|
|
||||||
errmsg = (
|
|
||||||
f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
|
|
||||||
f", got {self.observation_type} instead"
|
|
||||||
)
|
|
||||||
_LOGGER.error(errmsg)
|
|
||||||
raise ValueError(errmsg)
|
|
||||||
|
|
||||||
def _update_env_obs_box(self):
|
return self.obs_handler.space, self.obs_handler.current_observation
|
||||||
"""Update the environment's observation state based on the current status of nodes and links.
|
|
||||||
|
|
||||||
The structure of the observation space is described in :func:`~_init_box_observations`
|
|
||||||
This function can only be called if the observation space setting is set to BOX.
|
|
||||||
|
|
||||||
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
|
|
||||||
"""
|
|
||||||
assert self.observation_type == ObservationType.BOX
|
|
||||||
item_index = 0
|
|
||||||
|
|
||||||
# Do nodes first
|
|
||||||
for node_key, node in self.nodes.items():
|
|
||||||
self.env_obs[item_index][0] = int(node.node_id)
|
|
||||||
self.env_obs[item_index][1] = node.hardware_state.value
|
|
||||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
|
||||||
self.env_obs[item_index][2] = node.software_state.value
|
|
||||||
self.env_obs[item_index][3] = node.file_system_state_observed.value
|
|
||||||
else:
|
|
||||||
self.env_obs[item_index][2] = 0
|
|
||||||
self.env_obs[item_index][3] = 0
|
|
||||||
service_index = 4
|
|
||||||
if isinstance(node, ServiceNode):
|
|
||||||
for service in self.services_list:
|
|
||||||
if node.has_service(service):
|
|
||||||
self.env_obs[item_index][
|
|
||||||
service_index
|
|
||||||
] = node.get_service_state(service).value
|
|
||||||
else:
|
|
||||||
self.env_obs[item_index][service_index] = 0
|
|
||||||
service_index += 1
|
|
||||||
else:
|
|
||||||
# Not a service node
|
|
||||||
for service in self.services_list:
|
|
||||||
self.env_obs[item_index][service_index] = 0
|
|
||||||
service_index += 1
|
|
||||||
item_index += 1
|
|
||||||
|
|
||||||
# Now do links
|
|
||||||
for link_key, link in self.links.items():
|
|
||||||
self.env_obs[item_index][0] = int(link.get_id())
|
|
||||||
self.env_obs[item_index][1] = 0
|
|
||||||
self.env_obs[item_index][2] = 0
|
|
||||||
self.env_obs[item_index][3] = 0
|
|
||||||
protocol_list = link.get_protocol_list()
|
|
||||||
protocol_index = 0
|
|
||||||
for protocol in protocol_list:
|
|
||||||
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
|
|
||||||
protocol_index += 1
|
|
||||||
item_index += 1
|
|
||||||
|
|
||||||
def _update_env_obs_multidiscrete(self):
|
|
||||||
"""Update the environment's observation state based on the current status of nodes and links.
|
|
||||||
|
|
||||||
The structure of the observation space is described in :func:`~_init_multidiscrete_observations`
|
|
||||||
This function can only be called if the observation space setting is set to MULTIDISCRETE.
|
|
||||||
|
|
||||||
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
|
|
||||||
"""
|
|
||||||
assert self.observation_type == ObservationType.MULTIDISCRETE
|
|
||||||
obs = []
|
|
||||||
# 1. Set nodes
|
|
||||||
# Each node has the following variables in the observation space:
|
|
||||||
# - Hardware state
|
|
||||||
# - Software state
|
|
||||||
# - File System state
|
|
||||||
# - Service 1 state
|
|
||||||
# - Service 2 state
|
|
||||||
# - ...
|
|
||||||
# - Service N state
|
|
||||||
for node_key, node in self.nodes.items():
|
|
||||||
hardware_state = node.hardware_state.value
|
|
||||||
software_state = 0
|
|
||||||
file_system_state = 0
|
|
||||||
services_states = [0] * self.num_services
|
|
||||||
|
|
||||||
if isinstance(
|
|
||||||
node, ActiveNode
|
|
||||||
): # ServiceNode is a subclass of ActiveNode so no need to check that also
|
|
||||||
software_state = node.software_state.value
|
|
||||||
file_system_state = node.file_system_state_observed.value
|
|
||||||
|
|
||||||
if isinstance(node, ServiceNode):
|
|
||||||
for i, service in enumerate(self.services_list):
|
|
||||||
if node.has_service(service):
|
|
||||||
services_states[i] = node.get_service_state(service).value
|
|
||||||
|
|
||||||
obs.extend(
|
|
||||||
[
|
|
||||||
hardware_state,
|
|
||||||
software_state,
|
|
||||||
file_system_state,
|
|
||||||
*services_states,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 2. Set links
|
|
||||||
# Each link has just one variable in the observation space, it represents the traffic amount
|
|
||||||
# In order for the space to be fully MultiDiscrete, the amount of
|
|
||||||
# traffic on each link is quantised into a few levels:
|
|
||||||
# 0: no traffic (0% of bandwidth)
|
|
||||||
# 1: low traffic (0-33% of bandwidth)
|
|
||||||
# 2: medium traffic (33-66% of bandwidth)
|
|
||||||
# 3: high traffic (66-100% of bandwidth)
|
|
||||||
# 4: max traffic/overloaded (100% of bandwidth)
|
|
||||||
|
|
||||||
for link_key, link in self.links.items():
|
|
||||||
bandwidth = link.bandwidth
|
|
||||||
load = link.get_current_load()
|
|
||||||
|
|
||||||
if load <= 0:
|
|
||||||
traffic_level = 0
|
|
||||||
elif load >= bandwidth:
|
|
||||||
traffic_level = 4
|
|
||||||
else:
|
|
||||||
traffic_level = (load / bandwidth) // (1 / 3) + 1
|
|
||||||
|
|
||||||
obs.append(int(traffic_level))
|
|
||||||
|
|
||||||
self.env_obs = np.asarray(obs)
|
|
||||||
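Review note: the link traffic quantisation in the removed `_update_env_obs_multidiscrete` can be isolated as a pure function, which makes it easier to compare against whatever the new observation component does. The function name `quantise_traffic` is hypothetical; the formula itself is copied from the removed code.

```python
def quantise_traffic(load: float, bandwidth: float) -> int:
    """Map a link load onto the five traffic levels 0..4 (hypothetical helper)."""
    if load <= 0:
        return 0  # no traffic
    if load >= bandwidth:
        return 4  # fully utilised / overloaded
    # Split the open interval (0%, 100%) into three equal bands: levels 1, 2, 3.
    return int((load / bandwidth) // (1 / 3) + 1)


for load in (0, 100, 500, 999, 1000):
    print(load, quantise_traffic(load, bandwidth=1000))
```

With a bandwidth of 1000, as in the test laydown, a load of 999 (the `GREEN_IER` load in the new test config) falls in the "high" band rather than the "overloaded" one.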
|
|
||||||
def update_environent_obs(self):
|
def update_environent_obs(self):
|
||||||
"""Updates the observation space based on the node and link status."""
|
"""Updates the observation space based on the node and link status."""
|
||||||
if self.observation_type == ObservationType.BOX:
|
self.obs_handler.update_obs()
|
||||||
self._update_env_obs_box()
|
self.env_obs = self.obs_handler.current_observation
|
||||||
elif self.observation_type == ObservationType.MULTIDISCRETE:
|
|
||||||
self._update_env_obs_multidiscrete()
|
|
||||||
|
|
||||||
def load_lay_down_config(self):
|
def load_lay_down_config(self):
|
||||||
"""Loads config data in order to build the environment configuration."""
|
"""Loads config data in order to build the environment configuration."""
|
||||||
@@ -929,11 +702,9 @@ class Primaite(Env):
|
|||||||
elif item["item_type"] == "PORTS":
|
elif item["item_type"] == "PORTS":
|
||||||
# Create the list of ports
|
# Create the list of ports
|
||||||
self.create_ports_list(item)
|
self.create_ports_list(item)
|
||||||
elif item["item_type"] == "OBSERVATIONS":
|
|
||||||
# Get the observation information
|
|
||||||
self.get_observation_info(item)
|
|
||||||
else:
|
else:
|
||||||
# Do nothing (bad formatting)
|
item_type = item["item_type"]
|
||||||
|
_LOGGER.error(f"Invalid item_type: {item_type}")
|
||||||
pass
|
pass
|
||||||
|
|
||||||
_LOGGER.info("Environment configuration loaded")
|
_LOGGER.info("Environment configuration loaded")
|
||||||
@@ -1260,6 +1031,28 @@ class Primaite(Env):
|
|||||||
"""
|
"""
|
||||||
self.observation_type = ObservationType[observation_info["type"]]
|
self.observation_type = ObservationType[observation_info["type"]]
|
||||||
|
|
||||||
|
|
||||||
|
def get_action_info(self, action_info):
|
||||||
|
"""
|
||||||
|
Extracts action_info.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
action_info: A config data item representing action info
|
||||||
|
"""
|
||||||
|
self.action_type = ActionType[action_info["type"]]
|
||||||
|
|
||||||
|
def save_obs_config(self, obs_config: dict):
|
||||||
|
"""Cache the config for the observation space.
|
||||||
|
|
||||||
|
This is necessary as the observation space can't be built while reading the config,
|
||||||
|
it must be done after all the nodes, links, and services have been initialised.
|
||||||
|
|
||||||
|
:param obs_config: Parsed config relating to the observation space. The format is described in
|
||||||
|
:py:meth:`primaite.environment.observations.ObservationsHandler.from_config`
|
||||||
|
:type obs_config: dict
|
||||||
|
"""
|
||||||
|
self.obs_config = obs_config
|
||||||
|
|
||||||
def reset_environment(self):
|
def reset_environment(self):
|
||||||
"""
|
"""
|
||||||
# Resets environment.
|
# Resets environment.
|
||||||
|
|||||||
@@ -1,68 +0,0 @@
|
|||||||
- item_type: ACTIONS
|
|
||||||
type: NODE
|
|
||||||
- item_type: OBSERVATIONS
|
|
||||||
type: MULTIDISCRETE
|
|
||||||
- item_type: STEPS
|
|
||||||
steps: 5
|
|
||||||
- item_type: PORTS
|
|
||||||
ports_list:
|
|
||||||
- port: '80'
|
|
||||||
- item_type: SERVICES
|
|
||||||
service_list:
|
|
||||||
- name: TCP
|
|
||||||
|
|
||||||
########################################
|
|
||||||
# Nodes
|
|
||||||
- item_type: NODE
|
|
||||||
node_id: '1'
|
|
||||||
name: PC1
|
|
||||||
node_class: SERVICE
|
|
||||||
node_type: COMPUTER
|
|
||||||
priority: P5
|
|
||||||
hardware_state: 'ON'
|
|
||||||
ip_address: 192.168.1.1
|
|
||||||
software_state: GOOD
|
|
||||||
file_system_state: GOOD
|
|
||||||
services:
|
|
||||||
- name: TCP
|
|
||||||
port: '80'
|
|
||||||
state: GOOD
|
|
||||||
- item_type: NODE
|
|
||||||
node_id: '2'
|
|
||||||
name: SERVER
|
|
||||||
node_class: SERVICE
|
|
||||||
node_type: SERVER
|
|
||||||
priority: P5
|
|
||||||
hardware_state: 'ON'
|
|
||||||
ip_address: 192.168.1.2
|
|
||||||
software_state: GOOD
|
|
||||||
file_system_state: GOOD
|
|
||||||
services:
|
|
||||||
- name: TCP
|
|
||||||
port: '80'
|
|
||||||
state: GOOD
|
|
||||||
- item_type: NODE
|
|
||||||
node_id: '3'
|
|
||||||
name: SWITCH1
|
|
||||||
node_class: ACTIVE
|
|
||||||
node_type: SWITCH
|
|
||||||
priority: P2
|
|
||||||
hardware_state: 'ON'
|
|
||||||
ip_address: 192.168.1.3
|
|
||||||
software_state: GOOD
|
|
||||||
file_system_state: GOOD
|
|
||||||
|
|
||||||
########################################
|
|
||||||
# Links
|
|
||||||
- item_type: LINK
|
|
||||||
id: '4'
|
|
||||||
name: link1
|
|
||||||
bandwidth: 1000
|
|
||||||
source: '1'
|
|
||||||
destination: '3'
|
|
||||||
- item_type: LINK
|
|
||||||
id: '5'
|
|
||||||
name: link2
|
|
||||||
bandwidth: 1000
|
|
||||||
source: '3'
|
|
||||||
destination: '2'
|
|
||||||
@@ -1,15 +1,15 @@
|
|||||||
- item_type: ACTIONS
|
- item_type: ACTIONS
|
||||||
type: NODE
|
type: NODE
|
||||||
- item_type: OBSERVATIONS
|
|
||||||
type: BOX
|
|
||||||
- item_type: STEPS
|
- item_type: STEPS
|
||||||
steps: 5
|
steps: 5
|
||||||
- item_type: PORTS
|
- item_type: PORTS
|
||||||
ports_list:
|
ports_list:
|
||||||
- port: '80'
|
- port: '80'
|
||||||
|
- port: '53'
|
||||||
- item_type: SERVICES
|
- item_type: SERVICES
|
||||||
service_list:
|
service_list:
|
||||||
- name: TCP
|
- name: TCP
|
||||||
|
- name: UDP
|
||||||
|
|
||||||
########################################
|
########################################
|
||||||
# Nodes
|
# Nodes
|
||||||
@@ -21,12 +21,15 @@
|
|||||||
priority: P5
|
priority: P5
|
||||||
hardware_state: 'ON'
|
hardware_state: 'ON'
|
||||||
ip_address: 192.168.1.1
|
ip_address: 192.168.1.1
|
||||||
software_state: GOOD
|
software_state: COMPROMISED
|
||||||
file_system_state: GOOD
|
file_system_state: GOOD
|
||||||
services:
|
services:
|
||||||
- name: TCP
|
- name: TCP
|
||||||
port: '80'
|
port: '80'
|
||||||
state: GOOD
|
state: GOOD
|
||||||
|
- name: UDP
|
||||||
|
port: '53'
|
||||||
|
state: GOOD
|
||||||
- item_type: NODE
|
- item_type: NODE
|
||||||
node_id: '2'
|
node_id: '2'
|
||||||
name: SERVER
|
name: SERVER
|
||||||
@@ -41,6 +44,9 @@
|
|||||||
- name: TCP
|
- name: TCP
|
||||||
port: '80'
|
port: '80'
|
||||||
state: GOOD
|
state: GOOD
|
||||||
|
- name: UDP
|
||||||
|
port: '53'
|
||||||
|
state: OVERWHELMED
|
||||||
- item_type: NODE
|
- item_type: NODE
|
||||||
node_id: '3'
|
node_id: '3'
|
||||||
name: SWITCH1
|
name: SWITCH1
|
||||||
@@ -66,3 +72,33 @@
|
|||||||
bandwidth: 1000
|
bandwidth: 1000
|
||||||
source: '3'
|
source: '3'
|
||||||
destination: '2'
|
destination: '2'
|
||||||
|
|
||||||
|
#########################################
|
||||||
|
# IERS
|
||||||
|
- item_type: GREEN_IER
|
||||||
|
id: '5'
|
||||||
|
start_step: 0
|
||||||
|
end_step: 5
|
||||||
|
load: 999
|
||||||
|
protocol: TCP
|
||||||
|
port: '80'
|
||||||
|
source: '1'
|
||||||
|
destination: '2'
|
||||||
|
mission_criticality: 5
|
||||||
|
|
||||||
|
#########################################
|
||||||
|
# ACL Rules
|
||||||
|
- itemType: ACL_RULE
|
||||||
|
id: '6'
|
||||||
|
permission: ALLOW
|
||||||
|
source: 192.168.1.1
|
||||||
|
destination: 192.168.1.2
|
||||||
|
protocol: TCP
|
||||||
|
port: 80
|
||||||
|
- itemType: ACL_RULE
|
||||||
|
id: '7'
|
||||||
|
permission: ALLOW
|
||||||
|
source: 192.168.1.2
|
||||||
|
destination: 192.168.1.1
|
||||||
|
protocol: TCP
|
||||||
|
port: 80
|
||||||
96
tests/config/obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
# Main Config File
|
||||||
|
|
||||||
|
# Generic config values
|
||||||
|
# Choose one of these (dependent on Agent being trained)
|
||||||
|
# "STABLE_BASELINES3_PPO"
|
||||||
|
# "STABLE_BASELINES3_A2C"
|
||||||
|
# "GENERIC"
|
||||||
|
agentIdentifier: NONE
|
||||||
|
# Number of episodes to run per session
|
||||||
|
observationSpace:
|
||||||
|
components:
|
||||||
|
- name: LINK_TRAFFIC_LEVELS
|
||||||
|
options:
|
||||||
|
combine_service_traffic: false
|
||||||
|
quantisation_levels: 8
|
||||||
|
|
||||||
|
numEpisodes: 1
|
||||||
|
# Time delay between steps (for generic agents)
|
||||||
|
timeDelay: 1
|
||||||
|
# Filename of the scenario / laydown
|
||||||
|
configFilename: one_node_states_on_off_lay_down_config.yaml
|
||||||
|
# Type of session to be run (TRAINING or EVALUATION)
|
||||||
|
sessionType: TRAINING
|
||||||
|
# Determine whether to load an agent from file
|
||||||
|
loadAgent: False
|
||||||
|
# File path and file name of agent if you're loading one in
|
||||||
|
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observationSpaceHighValue: 1_000_000_000
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
allOk: 0
|
||||||
|
# Node Hardware State
|
||||||
|
offShouldBeOn: -10
|
||||||
|
offShouldBeResetting: -5
|
||||||
|
onShouldBeOff: -2
|
||||||
|
onShouldBeResetting: -5
|
||||||
|
resettingShouldBeOn: -5
|
||||||
|
resettingShouldBeOff: -2
|
||||||
|
resetting: -3
|
||||||
|
# Node Software or Service State
|
||||||
|
goodShouldBePatching: 2
|
||||||
|
goodShouldBeCompromised: 5
|
||||||
|
goodShouldBeOverwhelmed: 5
|
||||||
|
patchingShouldBeGood: -5
|
||||||
|
patchingShouldBeCompromised: 2
|
||||||
|
patchingShouldBeOverwhelmed: 2
|
||||||
|
patching: -3
|
||||||
|
compromisedShouldBeGood: -20
|
||||||
|
compromisedShouldBePatching: -20
|
||||||
|
compromisedShouldBeOverwhelmed: -20
|
||||||
|
compromised: -20
|
||||||
|
overwhelmedShouldBeGood: -20
|
||||||
|
overwhelmedShouldBePatching: -20
|
||||||
|
overwhelmedShouldBeCompromised: -20
|
||||||
|
overwhelmed: -20
|
||||||
|
# Node File System State
|
||||||
|
goodShouldBeRepairing: 2
|
||||||
|
goodShouldBeRestoring: 2
|
||||||
|
goodShouldBeCorrupt: 5
|
||||||
|
goodShouldBeDestroyed: 10
|
||||||
|
repairingShouldBeGood: -5
|
||||||
|
repairingShouldBeRestoring: 2
|
||||||
|
repairingShouldBeCorrupt: 2
|
||||||
|
repairingShouldBeDestroyed: 0
|
||||||
|
repairing: -3
|
||||||
|
restoringShouldBeGood: -10
|
||||||
|
restoringShouldBeRepairing: -2
|
||||||
|
restoringShouldBeCorrupt: 1
|
||||||
|
restoringShouldBeDestroyed: 2
|
||||||
|
restoring: -6
|
||||||
|
corruptShouldBeGood: -10
|
||||||
|
corruptShouldBeRepairing: -10
|
||||||
|
corruptShouldBeRestoring: -10
|
||||||
|
corruptShouldBeDestroyed: 2
|
||||||
|
corrupt: -10
|
||||||
|
destroyedShouldBeGood: -20
|
||||||
|
destroyedShouldBeRepairing: -20
|
||||||
|
destroyedShouldBeRestoring: -20
|
||||||
|
destroyedShouldBeCorrupt: -20
|
||||||
|
destroyed: -20
|
||||||
|
scanning: -2
|
||||||
|
# IER status
|
||||||
|
redIerRunning: -5
|
||||||
|
greenIerBlocked: -10
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
osPatchingDuration: 5 # The time taken to patch the OS
|
||||||
|
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||||
|
servicePatchingDuration: 5 # The time taken to patch a service
|
||||||
|
fileSystemRepairingLimit: 5 # The time taken to repair the file system
|
||||||
|
fileSystemRestoringLimit: 5 # The time taken to restore the file system
|
||||||
|
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||||
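Review note: a minimal sketch of how the `observationSpace` block in this config might be resolved into the `obs_config` dict used by the environment. The fallback to a single `NODE_LINK_TABLE` component is taken from `Primaite.__init__` in this diff; the helper names `resolve_obs_config` and `component_options`, and the assumption that the YAML key `observationSpace` maps directly onto `observation_config`, are hypothetical. The config is written as an already-parsed dict to avoid depending on a YAML library.

```python
from typing import Any, Dict, List, Optional

# Default taken from Primaite.__init__ in this diff.
DEFAULT_OBS_CONFIG: Dict[str, Any] = {"components": [{"name": "NODE_LINK_TABLE"}]}


def resolve_obs_config(observation_config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Fall back to the default when no observation config is given (hypothetical)."""
    return DEFAULT_OBS_CONFIG if observation_config is None else observation_config


def component_options(obs_config: Dict[str, Any]) -> List[Dict[str, Any]]:
    """List each component name with its options, defaulting to an empty dict."""
    return [
        {"name": component["name"], "options": component.get("options", {})}
        for component in obs_config["components"]
    ]


# Parsed equivalent of the observationSpace block in this file.
parsed = {
    "components": [
        {
            "name": "LINK_TRAFFIC_LEVELS",
            "options": {"combine_service_traffic": False, "quantisation_levels": 8},
        }
    ]
}

print(component_options(resolve_obs_config(parsed)))
print(component_options(resolve_obs_config(None)))
```

The second call mirrors `main_config_without_obs.yaml`, where the block is absent and the environment keeps the `NODE_LINK_TABLE` default.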
93
tests/config/obs_tests/main_config_NODE_LINK_TABLE.yaml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Main Config File
|
||||||
|
|
||||||
|
# Generic config values
|
||||||
|
# Choose one of these (dependent on Agent being trained)
|
||||||
|
# "STABLE_BASELINES3_PPO"
|
||||||
|
# "STABLE_BASELINES3_A2C"
|
||||||
|
# "GENERIC"
|
||||||
|
agentIdentifier: NONE
|
||||||
|
# Number of episodes to run per session
|
||||||
|
observationSpace:
|
||||||
|
components:
|
||||||
|
- name: NODE_LINK_TABLE
|
||||||
|
|
||||||
|
numEpisodes: 1
|
||||||
|
# Time delay between steps (for generic agents)
|
||||||
|
timeDelay: 1
|
||||||
|
# Filename of the scenario / laydown
|
||||||
|
configFilename: one_node_states_on_off_lay_down_config.yaml
|
||||||
|
# Type of session to be run (TRAINING or EVALUATION)
|
||||||
|
sessionType: TRAINING
|
||||||
|
# Determine whether to load an agent from file
|
||||||
|
loadAgent: False
|
||||||
|
# File path and file name of agent if you're loading one in
|
||||||
|
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observationSpaceHighValue: 1_000_000_000
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
allOk: 0
|
||||||
|
# Node Hardware State
|
||||||
|
offShouldBeOn: -10
|
||||||
|
offShouldBeResetting: -5
|
||||||
|
onShouldBeOff: -2
|
||||||
|
onShouldBeResetting: -5
|
||||||
|
resettingShouldBeOn: -5
|
||||||
|
resettingShouldBeOff: -2
|
||||||
|
resetting: -3
|
||||||
|
# Node Software or Service State
|
||||||
|
goodShouldBePatching: 2
|
||||||
|
goodShouldBeCompromised: 5
|
||||||
|
goodShouldBeOverwhelmed: 5
|
||||||
|
patchingShouldBeGood: -5
|
||||||
|
patchingShouldBeCompromised: 2
|
||||||
|
patchingShouldBeOverwhelmed: 2
|
||||||
|
patching: -3
|
||||||
|
compromisedShouldBeGood: -20
|
||||||
|
compromisedShouldBePatching: -20
|
||||||
|
compromisedShouldBeOverwhelmed: -20
|
||||||
|
compromised: -20
|
||||||
|
overwhelmedShouldBeGood: -20
|
||||||
|
overwhelmedShouldBePatching: -20
|
||||||
|
overwhelmedShouldBeCompromised: -20
|
||||||
|
overwhelmed: -20
|
||||||
|
# Node File System State
|
||||||
|
goodShouldBeRepairing: 2
|
||||||
|
goodShouldBeRestoring: 2
|
||||||
|
goodShouldBeCorrupt: 5
|
||||||
|
goodShouldBeDestroyed: 10
|
||||||
|
repairingShouldBeGood: -5
|
||||||
|
repairingShouldBeRestoring: 2
|
||||||
|
repairingShouldBeCorrupt: 2
|
||||||
|
repairingShouldBeDestroyed: 0
|
||||||
|
repairing: -3
|
||||||
|
restoringShouldBeGood: -10
|
||||||
|
restoringShouldBeRepairing: -2
|
||||||
|
restoringShouldBeCorrupt: 1
|
||||||
|
restoringShouldBeDestroyed: 2
|
||||||
|
restoring: -6
|
||||||
|
corruptShouldBeGood: -10
|
||||||
|
corruptShouldBeRepairing: -10
|
||||||
|
corruptShouldBeRestoring: -10
|
||||||
|
corruptShouldBeDestroyed: 2
|
||||||
|
corrupt: -10
|
||||||
|
destroyedShouldBeGood: -20
|
||||||
|
destroyedShouldBeRepairing: -20
|
||||||
|
destroyedShouldBeRestoring: -20
|
||||||
|
destroyedShouldBeCorrupt: -20
|
||||||
|
destroyed: -20
|
||||||
|
scanning: -2
|
||||||
|
# IER status
|
||||||
|
redIerRunning: -5
|
||||||
|
greenIerBlocked: -10
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
osPatchingDuration: 5 # The time taken to patch the OS
|
||||||
|
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||||
|
servicePatchingDuration: 5 # The time taken to patch a service
|
||||||
|
fileSystemRepairingLimit: 5 # The time taken to repair the file system
|
||||||
|
fileSystemRestoringLimit: 5 # The time taken to restore the file system
|
||||||
|
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||||
93
tests/config/obs_tests/main_config_NODE_STATUSES.yaml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Main Config File
|
||||||
|
|
||||||
|
# Generic config values
|
||||||
|
# Choose one of these (dependent on Agent being trained)
|
||||||
|
# "STABLE_BASELINES3_PPO"
|
||||||
|
# "STABLE_BASELINES3_A2C"
|
||||||
|
# "GENERIC"
|
||||||
|
agentIdentifier: NONE
|
||||||
|
# Number of episodes to run per session
|
||||||
|
observationSpace:
|
||||||
|
components:
|
||||||
|
- name: NODE_STATUSES
|
||||||
|
|
||||||
|
numEpisodes: 1
|
||||||
|
# Time delay between steps (for generic agents)
|
||||||
|
timeDelay: 1
|
||||||
|
# Filename of the scenario / laydown
|
||||||
|
configFilename: one_node_states_on_off_lay_down_config.yaml
|
||||||
|
# Type of session to be run (TRAINING or EVALUATION)
|
||||||
|
sessionType: TRAINING
|
||||||
|
# Determine whether to load an agent from file
|
||||||
|
loadAgent: False
|
||||||
|
# File path and file name of agent if you're loading one in
|
||||||
|
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observationSpaceHighValue: 1_000_000_000
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
allOk: 0
|
||||||
|
# Node Hardware State
|
||||||
|
offShouldBeOn: -10
|
||||||
|
offShouldBeResetting: -5
|
||||||
|
onShouldBeOff: -2
|
||||||
|
onShouldBeResetting: -5
|
||||||
|
resettingShouldBeOn: -5
|
||||||
|
resettingShouldBeOff: -2
|
||||||
|
resetting: -3
|
||||||
|
# Node Software or Service State
|
||||||
|
goodShouldBePatching: 2
|
||||||
|
goodShouldBeCompromised: 5
|
||||||
|
goodShouldBeOverwhelmed: 5
|
||||||
|
patchingShouldBeGood: -5
|
||||||
|
patchingShouldBeCompromised: 2
|
||||||
|
patchingShouldBeOverwhelmed: 2
|
||||||
|
patching: -3
|
||||||
|
compromisedShouldBeGood: -20
|
||||||
|
compromisedShouldBePatching: -20
|
||||||
|
compromisedShouldBeOverwhelmed: -20
|
||||||
|
compromised: -20
|
||||||
|
overwhelmedShouldBeGood: -20
|
||||||
|
overwhelmedShouldBePatching: -20
|
||||||
|
overwhelmedShouldBeCompromised: -20
|
||||||
|
overwhelmed: -20
|
||||||
|
# Node File System State
|
||||||
|
goodShouldBeRepairing: 2
|
||||||
|
goodShouldBeRestoring: 2
|
||||||
|
goodShouldBeCorrupt: 5
|
||||||
|
goodShouldBeDestroyed: 10
|
||||||
|
repairingShouldBeGood: -5
|
||||||
|
repairingShouldBeRestoring: 2
|
||||||
|
repairingShouldBeCorrupt: 2
|
||||||
|
repairingShouldBeDestroyed: 0
|
||||||
|
repairing: -3
|
||||||
|
restoringShouldBeGood: -10
|
||||||
|
restoringShouldBeRepairing: -2
|
||||||
|
restoringShouldBeCorrupt: 1
|
||||||
|
restoringShouldBeDestroyed: 2
|
||||||
|
restoring: -6
|
||||||
|
corruptShouldBeGood: -10
|
||||||
|
corruptShouldBeRepairing: -10
|
||||||
|
corruptShouldBeRestoring: -10
|
||||||
|
corruptShouldBeDestroyed: 2
|
||||||
|
corrupt: -10
|
||||||
|
destroyedShouldBeGood: -20
|
||||||
|
destroyedShouldBeRepairing: -20
|
||||||
|
destroyedShouldBeRestoring: -20
|
||||||
|
destroyedShouldBeCorrupt: -20
|
||||||
|
destroyed: -20
|
||||||
|
scanning: -2
|
||||||
|
# IER status
|
||||||
|
redIerRunning: -5
|
||||||
|
greenIerBlocked: -10
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
osPatchingDuration: 5 # The time taken to patch the OS
|
||||||
|
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||||
|
servicePatchingDuration: 5 # The time taken to patch a service
|
||||||
|
fileSystemRepairingLimit: 5 # The time taken to repair the file system
|
||||||
|
fileSystemRestoringLimit: 5 # The time taken to restore the file system
|
||||||
|
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||||
89
tests/config/obs_tests/main_config_without_obs.yaml
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
# Main Config File
|
||||||
|
|
||||||
|
# Generic config values
|
||||||
|
# Choose one of these (dependent on Agent being trained)
|
||||||
|
# "STABLE_BASELINES3_PPO"
|
||||||
|
# "STABLE_BASELINES3_A2C"
|
||||||
|
# "GENERIC"
|
||||||
|
agentIdentifier: NONE
|
||||||
|
# Number of episodes to run per session
|
||||||
|
numEpisodes: 1
|
||||||
|
# Time delay between steps (for generic agents)
|
||||||
|
timeDelay: 1
|
||||||
|
# Filename of the scenario / laydown
|
||||||
|
configFilename: one_node_states_on_off_lay_down_config.yaml
|
||||||
|
# Type of session to be run (TRAINING or EVALUATION)
|
||||||
|
sessionType: TRAINING
|
||||||
|
# Determine whether to load an agent from file
|
||||||
|
loadAgent: False
|
||||||
|
# File path and file name of agent if you're loading one in
|
||||||
|
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||||
|
|
||||||
|
# Environment config values
|
||||||
|
# The high value for the observation space
|
||||||
|
observationSpaceHighValue: 1_000_000_000
|
||||||
|
|
||||||
|
# Reward values
|
||||||
|
# Generic
|
||||||
|
allOk: 0
|
||||||
|
# Node Hardware State
|
||||||
|
offShouldBeOn: -10
|
||||||
|
offShouldBeResetting: -5
|
||||||
|
onShouldBeOff: -2
|
||||||
|
onShouldBeResetting: -5
|
||||||
|
resettingShouldBeOn: -5
|
||||||
|
resettingShouldBeOff: -2
|
||||||
|
resetting: -3
|
||||||
|
# Node Software or Service State
|
||||||
|
goodShouldBePatching: 2
|
||||||
|
goodShouldBeCompromised: 5
|
||||||
|
goodShouldBeOverwhelmed: 5
|
||||||
|
patchingShouldBeGood: -5
|
||||||
|
patchingShouldBeCompromised: 2
|
||||||
|
patchingShouldBeOverwhelmed: 2
|
||||||
|
patching: -3
|
||||||
|
compromisedShouldBeGood: -20
|
||||||
|
compromisedShouldBePatching: -20
|
||||||
|
compromisedShouldBeOverwhelmed: -20
|
||||||
|
compromised: -20
|
||||||
|
overwhelmedShouldBeGood: -20
|
||||||
|
overwhelmedShouldBePatching: -20
|
||||||
|
overwhelmedShouldBeCompromised: -20
|
||||||
|
overwhelmed: -20
|
||||||
|
# Node File System State
|
||||||
|
goodShouldBeRepairing: 2
|
||||||
|
goodShouldBeRestoring: 2
|
||||||
|
goodShouldBeCorrupt: 5
|
||||||
|
goodShouldBeDestroyed: 10
|
||||||
|
repairingShouldBeGood: -5
|
||||||
|
repairingShouldBeRestoring: 2
|
||||||
|
repairingShouldBeCorrupt: 2
|
||||||
|
repairingShouldBeDestroyed: 0
|
||||||
|
repairing: -3
|
||||||
|
restoringShouldBeGood: -10
|
||||||
|
restoringShouldBeRepairing: -2
|
||||||
|
restoringShouldBeCorrupt: 1
|
||||||
|
restoringShouldBeDestroyed: 2
|
||||||
|
restoring: -6
|
||||||
|
corruptShouldBeGood: -10
|
||||||
|
corruptShouldBeRepairing: -10
|
||||||
|
corruptShouldBeRestoring: -10
|
||||||
|
corruptShouldBeDestroyed: 2
|
||||||
|
corrupt: -10
|
||||||
|
destroyedShouldBeGood: -20
|
||||||
|
destroyedShouldBeRepairing: -20
|
||||||
|
destroyedShouldBeRestoring: -20
|
||||||
|
destroyedShouldBeCorrupt: -20
|
||||||
|
destroyed: -20
|
||||||
|
scanning: -2
|
||||||
|
# IER status
|
||||||
|
redIerRunning: -5
|
||||||
|
greenIerBlocked: -10
|
||||||
|
|
||||||
|
# Patching / Reset durations
|
||||||
|
osPatchingDuration: 5 # The time taken to patch the OS
|
||||||
|
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||||
|
servicePatchingDuration: 5 # The time taken to patch a service
|
||||||
|
fileSystemRepairingLimit: 5 # The time taken to repair the file system
|
||||||
|
fileSystemRestoringLimit: 5 # The time taken to restore the file system
|
||||||
|
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||||
@@ -1,36 +1,220 @@
|
|||||||
"""Test env creation and behaviour with different observation spaces."""
|
"""Test env creation and behaviour with different observation spaces."""
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from primaite.environment.observations import (
|
||||||
|
NodeLinkTable,
|
||||||
|
NodeStatuses,
|
||||||
|
ObservationsHandler,
|
||||||
|
)
|
||||||
|
from primaite.environment.primaite_env import Primaite
|
||||||
from tests import TEST_CONFIG_ROOT
|
from tests import TEST_CONFIG_ROOT
|
||||||
from tests.conftest import _get_primaite_env_from_config
|
from tests.conftest import _get_primaite_env_from_config
|
||||||
|
|
||||||
|
|
||||||
def test_creating_env_with_box_obs():
|
@pytest.fixture
|
||||||
"""Try creating env with box observation space."""
|
def env(request):
|
||||||
env = _get_primaite_env_from_config(
|
"""Build Primaite environment for integration tests of observation space."""
|
||||||
training_config_path=TEST_CONFIG_ROOT
|
marker = request.node.get_closest_marker("env_config_paths")
|
||||||
/ "one_node_states_on_off_main_config.yaml",
|
main_config_path = marker.args[0]["main_config_path"]
|
||||||
lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml",
|
lay_down_config_path = marker.args[0]["lay_down_config_path"]
|
||||||
|
env, _ = _get_primaite_env_from_config(
|
||||||
|
main_config_path=main_config_path,
|
||||||
|
lay_down_config_path=lay_down_config_path,
|
||||||
)
|
)
|
||||||
|
yield env
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.env_config_paths(
|
||||||
|
dict(
|
||||||
|
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||||
|
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
def test_default_obs_space(env: Primaite):
|
||||||
|
"""Create environment with no obs space defined in config and check that the default obs space was created."""
|
||||||
env.update_environent_obs()
|
env.update_environent_obs()
|
||||||
|
|
||||||
# we have three nodes and two links, with one service
|
components = env.obs_handler.registered_obs_components
|
||||||
# therefore the box observation space will have:
|
|
||||||
# * 5 columns (four fixed and one for the service)
|
assert len(components) == 1
|
||||||
# * 5 rows (3 nodes + 2 links)
|
assert isinstance(components[0], NodeLinkTable)
|
||||||
assert env.env_obs.shape == (5, 5)
|
|
||||||
|
|
||||||
|
|
||||||
def test_creating_env_with_multidiscrete_obs():
|
@pytest.mark.env_config_paths(
|
||||||
"""Try creating env with MultiDiscrete observation space."""
|
dict(
|
||||||
env = _get_primaite_env_from_config(
|
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
|
||||||
training_config_path=TEST_CONFIG_ROOT
|
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||||
/ "one_node_states_on_off_main_config.yaml",
|
|
||||||
lay_down_config_path=TEST_CONFIG_ROOT
|
|
||||||
/ "multidiscrete_obs_space_laydown_config.yaml",
|
|
||||||
)
|
)
|
||||||
env.update_environent_obs()
|
)
|
||||||
|
def test_registering_components(env: Primaite):
|
||||||
|
"""Test registering and deregistering a component."""
|
||||||
|
handler = ObservationsHandler()
|
||||||
|
component = NodeStatuses(env)
|
||||||
|
handler.register(component)
|
||||||
|
assert component in handler.registered_obs_components
|
||||||
|
handler.deregister(component)
|
||||||
|
assert component not in handler.registered_obs_components
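The register/deregister behaviour exercised by this test can be sketched without PrimAITE. `SketchHandler` and `SketchComponent` below are hypothetical stand-ins, not the real `ObservationsHandler` or `NodeStatuses`; only the names `register`, `deregister` and `registered_obs_components` are taken from the test, and the list-based storage is an assumption.

```python
class SketchComponent:
    """Hypothetical stand-in for an observation component."""

    def __init__(self, name):
        self.name = name


class SketchHandler:
    """Hypothetical handler that keeps its components in a plain list."""

    def __init__(self):
        self.registered_obs_components = []

    def register(self, component):
        # Add the component so it contributes to the observation.
        self.registered_obs_components.append(component)

    def deregister(self, component):
        # Remove a previously registered component.
        self.registered_obs_components.remove(component)


handler = SketchHandler()
component = SketchComponent("node_statuses")

handler.register(component)
registered_after_register = component in handler.registered_obs_components

handler.deregister(component)
registered_after_deregister = component in handler.registered_obs_components
```

A real handler would also rebuild the combined observation space after each change; that step is left out of this sketch.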
|
||||||
|
|
||||||
# we have three nodes and two links, with one service
|
|
||||||
# the nodes have hardware, OS, FS, and service, the links just have bandwidth,
|
@pytest.mark.env_config_paths(
|
||||||
# therefore we need 3*4 + 2 observations
|
dict(
|
||||||
assert env.env_obs.shape == (3 * 4 + 2,)
|
main_config_path=TEST_CONFIG_ROOT
|
||||||
|
/ "obs_tests/main_config_NODE_LINK_TABLE.yaml",
|
||||||
|
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
class TestNodeLinkTable:
|
||||||
|
"""Test the NodeLinkTable observation component (in isolation)."""
|
||||||
|
|
||||||
|
def test_obs_shape(self, env: Primaite):
|
||||||
|
"""Try creating env with box observation space."""
|
||||||
|
env.update_environent_obs()
|
||||||
|
|
||||||
|
# we have three nodes and two links, with two services
|
||||||
|
# therefore the box observation space will have:
|
||||||
|
# * 5 rows (3 nodes + 2 links)
|
||||||
|
# * 6 columns (four fixed and two for the services)
|
||||||
|
assert env.env_obs.shape == (5, 6)
|
||||||
|
|
||||||
|
def test_value(self, env: Primaite):
|
||||||
|
"""Test that the observation is generated correctly.
|
||||||
|
|
||||||
|
The laydown has:
|
||||||
|
* 3 nodes (2 service nodes and 1 active node)
|
||||||
|
* 2 services
|
||||||
|
* 2 links
|
||||||
|
|
||||||
|
Both service nodes have both services. All states are GOOD except the OS on node 1 (compromised) and the UDP service on node 2 (overwhelmed), therefore the expected observation value is:
|
||||||
|
|
||||||
|
* Node 1:
|
||||||
|
* 1 (id)
|
||||||
|
* 1 (good hardware state)
|
||||||
|
* 3 (compromised OS state)
|
||||||
|
* 1 (good file system state)
|
||||||
|
* 1 (good TCP state)
|
||||||
|
* 1 (good UDP state)
|
||||||
|
* Node 2:
|
||||||
|
* 2 (id)
|
||||||
|
* 1 (good hardware state)
|
||||||
|
* 1 (good OS state)
|
||||||
|
* 1 (good file system state)
|
||||||
|
* 1 (good TCP state)
|
||||||
|
* 4 (overwhelmed UDP state)
|
||||||
|
* Node 3 (active node):
|
||||||
|
* 3 (id)
|
||||||
|
* 1 (good hardware state)
|
||||||
|
* 1 (good OS state)
|
||||||
|
* 1 (good file system state)
|
||||||
|
* 0 (doesn't have service1)
|
||||||
|
* 0 (doesn't have service2)
|
||||||
|
* Link 1:
|
||||||
|
* 4 (id)
|
||||||
|
* 0 (n/a hardware state)
|
||||||
|
* 0 (n/a OS state)
|
||||||
|
* 0 (n/a file system state)
|
||||||
|
* 999 (999 traffic for service1)
|
||||||
|
* 0 (no traffic for service2)
|
||||||
|
* Link 2:
|
||||||
|
* 5 (id)
|
||||||
|
* 0 (n/a hardware state)
|
||||||
|
* 0 (n/a OS state)
|
||||||
|
* 0 (n/a file system state)
|
||||||
|
* 999 (999 traffic for service1)
|
||||||
|
* 0 (no traffic for service2)
|
||||||
|
"""
|
||||||
|
# act = np.asarray([0,])
|
||||||
|
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
|
||||||
|
|
||||||
|
assert np.array_equal(
|
||||||
|
obs,
|
||||||
|
[
|
||||||
|
[1, 1, 3, 1, 1, 1],
|
||||||
|
[2, 1, 1, 1, 1, 4],
|
||||||
|
[3, 1, 1, 1, 0, 0],
|
||||||
|
[4, 0, 0, 0, 999, 0],
|
||||||
|
[5, 0, 0, 0, 999, 0],
|
||||||
|
],
|
||||||
|
)
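The row layout described in the docstring above can be reproduced in plain Python. This is only a sketch of the table format, assuming the column order `[id, hardware, OS, file system, service1, service2]` given in the docstring; `NODES`, `LINKS` and `build_table` are hypothetical names, not PrimAITE API.

```python
# Node states as listed in the docstring (1 = good, 3 = compromised,
# 4 = overwhelmed, 0 = service not present).
NODES = [
    {"id": 1, "hardware": 1, "os": 3, "file_system": 1, "services": [1, 1]},
    {"id": 2, "hardware": 1, "os": 1, "file_system": 1, "services": [1, 4]},
    {"id": 3, "hardware": 1, "os": 1, "file_system": 1, "services": [0, 0]},
]

# Links carry per-service traffic instead of per-service state.
LINKS = [
    {"id": 4, "traffic": [999, 0]},
    {"id": 5, "traffic": [999, 0]},
]


def build_table(nodes, links):
    """Return one row per node, then one row per link."""
    rows = []
    for node in nodes:
        rows.append(
            [node["id"], node["hardware"], node["os"], node["file_system"]]
            + node["services"]
        )
    for link in links:
        # Hardware, OS and file system columns do not apply to links.
        rows.append([link["id"], 0, 0, 0] + link["traffic"])
    return rows


table = build_table(NODES, LINKS)
```

With three nodes, two links and two services this gives 5 rows of 6 columns, which is the `(5, 6)` shape asserted in `test_obs_shape`.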
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.env_config_paths(
|
||||||
|
dict(
|
||||||
|
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
|
||||||
|
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
class TestNodeStatuses:
|
||||||
|
"""Test the NodeStatuses observation component (in isolation)."""
|
||||||
|
|
||||||
|
def test_obs_shape(self, env: Primaite):
|
||||||
|
"""Try creating env with NodeStatuses as the only component."""
|
||||||
|
assert env.env_obs.shape == (15,)
|
||||||
|
|
||||||
|
def test_values(self, env: Primaite):
|
||||||
|
"""Test that the hardware and software states are encoded correctly.
|
||||||
|
|
||||||
|
The laydown has:
|
||||||
|
* one node with a compromised operating system state
|
||||||
|
* one node with two services, and the second service is overwhelmed.
|
||||||
|
* all other states are good or null
|
||||||
|
Therefore, the expected state is:
|
||||||
|
* node 1:
|
||||||
|
* hardware = good (1)
|
||||||
|
* OS = compromised (3)
|
||||||
|
* file system = good (1)
|
||||||
|
* service 1 = good (1)
|
||||||
|
* service 2 = good (1)
|
||||||
|
* node 2:
|
||||||
|
* hardware = good (1)
|
||||||
|
* OS = good (1)
|
||||||
|
* file system = good (1)
|
||||||
|
* service 1 = good (1)
|
||||||
|
* service 2 = overwhelmed (4)
|
||||||
|
* node 3 (switch):
|
||||||
|
* hardware = good (1)
|
||||||
|
* OS = good (1)
|
||||||
|
* file system = good (1)
|
||||||
|
* service 1 = n/a (0)
|
||||||
|
* service 2 = n/a (0)
|
||||||
|
"""
|
||||||
|
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
|
||||||
|
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
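The 15-element vector asserted above is the per-node states from the docstring laid end to end. A minimal sketch of that flattening, assuming the per-node order `[hardware, OS, file system, service 1, service 2]`; `NODE_STATES` and `flatten_node_statuses` are hypothetical names used only for illustration.

```python
# (hardware, OS, file system, (service 1, service 2)) for each node,
# copied from the docstring of test_values.
NODE_STATES = [
    (1, 3, 1, (1, 1)),  # node 1: compromised OS
    (1, 1, 1, (1, 4)),  # node 2: service 2 overwhelmed
    (1, 1, 1, (0, 0)),  # node 3: no services
]


def flatten_node_statuses(node_states):
    """Concatenate every node's states into one flat list."""
    flat = []
    for hardware, os_state, file_system, services in node_states:
        flat.extend([hardware, os_state, file_system])
        flat.extend(services)
    return flat


flat_obs = flatten_node_statuses(NODE_STATES)
```

Three nodes with five entries each account for the `(15,)` shape checked in `test_obs_shape`.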
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.env_config_paths(
|
||||||
|
dict(
|
||||||
|
main_config_path=TEST_CONFIG_ROOT
|
||||||
|
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
|
||||||
|
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
class TestLinkTrafficLevels:
|
||||||
|
"""Test the LinkTrafficLevels observation component (in isolation)."""
|
||||||
|
|
||||||
|
def test_obs_shape(self, env: Primaite):
|
||||||
|
"""Try creating env with MultiDiscrete observation space."""
|
||||||
|
env.update_environent_obs()
|
||||||
|
|
||||||
|
# we have two links and two services, so the shape should be 2 * 2
|
||||||
|
assert env.env_obs.shape == (2 * 2,)
|
||||||
|
|
||||||
|
def test_values(self, env: Primaite):
|
||||||
|
"""Test that traffic values are encoded correctly.
|
||||||
|
|
||||||
|
The laydown has:
|
||||||
|
* two services
|
||||||
|
* three nodes
|
||||||
|
* two links
|
||||||
|
* an IER trying to send 999 bits of data over both links the whole time (via the first service)
|
||||||
|
* link bandwidth of 1000, therefore the utilisation is 99.9%
|
||||||
|
"""
|
||||||
|
obs, reward, done, info = env.step(0)
|
||||||
|
obs, reward, done, info = env.step(0)
|
||||||
|
|
||||||
|
# the observation space has combine_service_traffic set to False, so the space has this format:
|
||||||
|
# [link1_service1, link1_service2, link2_service1, link2_service2]
|
||||||
|
# we send 999 bits of data via link1 and link2 on service 1.
|
||||||
|
# therefore the first and third elements should be 6 and all others 0
|
||||||
|
# (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%)
|
||||||
|
assert np.array_equal(obs, [6, 0, 6, 0])
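One quantisation scheme that reproduces the asserted value is sketched below. It is an assumption, not PrimAITE's confirmed implementation: level 0 means no traffic, the top level means the link is saturated, and the remaining levels split the range in between evenly. `traffic_level` and its `levels` parameter are hypothetical, and the bin edges this produces may differ from the library's.

```python
def traffic_level(load, bandwidth, levels=8):
    """Map a link load to a discrete level in [0, levels - 1] (assumed scheme)."""
    if load <= 0:
        return 0  # no traffic
    if load >= bandwidth:
        return levels - 1  # fully utilised
    # Integer arithmetic avoids floating-point edge cases at bin boundaries.
    return (load * (levels - 2)) // bandwidth + 1


# Observation order: [link1_service1, link1_service2, link2_service1, link2_service2]
loads = [999, 0, 999, 0]
sketch_obs = [traffic_level(load, bandwidth=1000) for load in loads]
```

Under this assumption a load of 999 on a 1000-bandwidth link falls in level 6 and only a load equal to the bandwidth reaches level 7.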
|
||||||
|
|||||||