Merge remote-tracking branch 'origin/dev' into feature/898-Fix-the-functionality-of-resetting-a-node

This commit is contained in:
Chris McCarthy
2023-06-12 14:20:16 +01:00
23 changed files with 1758 additions and 420 deletions

View File

@@ -184,16 +184,13 @@ All ACL rules are considered when applying an IER. Logic follows the order of ru
Observation Spaces
******************
The observation space provides the blue agent with information about the current status of nodes and links.
The OpenAI Gym observation space provides the status of all nodes and links across the whole system:
PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted.
* Nodes (in terms of hardware state, Software State, file system state and services state)
* Links (in terms of current loading for each service/protocol)
The observation space can be configured as a ``gym.spaces.Box`` or ``gym.spaces.MultiDiscrete``, by setting the ``OBSERVATIONS`` parameter in the laydown config.
Box-type observation space
--------------------------
NodeLinkTable component
-----------------------
For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below:
An example observation space is provided below:
@@ -251,8 +248,6 @@ An example observation space is provided below:
- 5000
- 0
The observation space is a 6 x 6 Box type (OpenAI Gym Space) in this example. This is made up from the node and link information detailed below.
For the nodes, the following values are represented:
* ID
@@ -294,9 +289,9 @@ For the links, the following statuses are represented:
* SoftwareState = N/A
* Protocol = loading in bits/s
MultiDiscrete-type observation space
------------------------------------
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by integers.
NodeStatus component
----------------------
This is a MultiDiscrete observation space that can be though of as a one-dimensional vector of discrete states, represented by integers.
The example above would have the following structure:
.. code-block::
@@ -305,9 +300,6 @@ The example above would have the following structure:
node1_info
node2_info
node3_info
link1_status
link2_status
link3_status
]
Each ``node_info`` contains the following:
@@ -322,7 +314,25 @@ Each ``node_info`` contains the following:
service2_state (0=none, 1=GOOD, 2=PATCHING, 3=COMPROMISED)
]
Each ``link_status`` is just a number from 0-4 representing the network load in relation to bandwidth.
In a network with three nodes and two services, the full observation space would have 15 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
.. code-block::
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4])
LinkTrafficLevels
-----------------
This component is a MultiDiscrete space showing the traffic flow levels on the links in the network, after applying a threshold to convert it from a continuous to a discrete value.
The number of bins can be customised with 5 being the default. It has the following strucutre:
.. code-block::
[
link1_status
link2_status
link3_status
]
Each ``link_status`` is a number from 0-4 representing the network load in relation to bandwidth.
.. code-block::
@@ -332,12 +342,11 @@ Each ``link_status`` is just a number from 0-4 representing the network load in
3 = high traffic (<100%)
4 = max traffic/ overwhelmed (100%)
The full observation space would have 15 node-related elements and 3 link-related elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
If the network has three links, the full observation space would have 3 elements. It can be written with ``gym`` notation to indicate the number of discrete options for each of the elements of the observation space. For example:
.. code-block::
gym.spaces.MultiDiscrete([4,5,6,4,4,4,5,6,4,4,4,5,6,4,4,5,5,5])
gym.spaces.MultiDiscrete([5,5,5])
Action Spaces
**************
@@ -346,29 +355,40 @@ The action space available to the blue agent comes in two types:
1. Node-based
2. Access Control List
3. Any (Agent can take both node-based and ACL-based actions)
The choice of action space used during a training session is determined in the config_[name].yaml file.
**Node-Based**
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym multidiscrete type, as follows:
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows:
* [0, num nodes] - Node ID (0 = nothing, node ID)
* [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state)
* [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore)
* [0, num services] - Resolves to service ID (0 = nothing, resolves to service)
* Dictionary item {... ,1: [x1, x2, x3,x4] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, num nodes] - Node ID (0 = nothing, node ID)
* [0, 4] - What property it's acting on (0 = nothing, 1 = state, 2 = SoftwareState, 3 = service state, 4 = file system state)
* [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore)
* [0, num services] - Resolves to service ID (0 = nothing, resolves to service)
**Access Control List**
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI multidiscrete type, as follows:
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
* [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
* [0, 1] - Permission (0 = DENY, 1 = ALLOW)
* [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
* [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
* [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
**ANY**
The agent is able to carry out both **Node-Based** and **Access Control List** operations.
This means the dictionary will contain key-value pairs in the format of BOTH Node-Based and Access Control List as seen above.
Rewards
*******

View File

@@ -288,6 +288,28 @@ The config_[name].yaml file consists of the following attributes:
Determines whether a NODE or ACL action space format is adopted for the session
* **itemType: OBSERVATION_SPACE** [dict]
Allows for user to configure observation space by combining one or more observation components. List of available
components is is :py:mod:'primaite.environment.observations'.
The observation space config item should have a ``components`` key which is a list of components. Each component
config must have a ``name`` key, and can optionally have an ``options`` key. The ``options`` are passed to the
component while it is being initialised.
This example illustrates the correct format for the observation space config item
.. code-block::yaml
- itemType: OBSERVATION_SPACE
components:
- name: LINK_TRAFFIC_LEVELS
options:
combine_service_traffic: false
quantisation_levels: 8
- name: NODE_STATUSES
- name: LINK_TRAFFIC_LEVELS
* **itemType: STEPS** [int]
Determines the number of steps to run in each episode of the session

View File

@@ -1,3 +1,5 @@
[pytest]
testpaths =
tests
markers =
env_config_paths

View File

View File

@@ -0,0 +1,127 @@
from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
2
] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
return new_action
def transform_action_acl_readable(action):
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == 0:
new_action[n + 2] = "ANY"
return new_action
def is_valid_node_action(action):
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
"""
action_r = transform_action_node_readable(action)
node_property = action_r[1]
node_action = action_r[2]
# print("node property", node_property, "\nnode action", node_action)
if node_property == "NONE":
return False
if node_action == "NONE":
return False
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
# Software States can only do Nothing or Patch
return False
return True
def is_valid_acl_action(action):
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect.
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
"""
action_r = transform_action_acl_readable(action)
action_decision = action_r[0]
action_permission = action_r[1]
action_source_id = action_r[2]
action_destination_id = action_r[3]
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
# ACL rule towards itself
return False
if action_permission == "DENY":
# DENY is unnecessary, we can create and delete allow rules instead
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
return False
return True
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action."""
if is_valid_acl_action(action) is False:
return False
action_r = transform_action_acl_readable(action)
action_protocol = action_r[4]
action_port = action_r[5]
# Don't allow protocols or ports to be ANY
# in the future we might want to do the opposite, and only have ANY option for ports and service
if action_protocol == "ANY":
return False
if action_port == "ANY":
return False
return True

View File

@@ -9,6 +9,7 @@ class ConfigValuesMain(object):
"""Init."""
# Generic
self.agent_identifier = "" # the agent in use
self.observation_config = None # observation space config
self.num_episodes = 0 # number of episodes to train over
self.num_steps = 0 # number of steps in an episode
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only

View File

@@ -51,6 +51,7 @@ class SoftwareState(Enum):
class NodePOLType(Enum):
"""Node Pattern of Life type enumeration."""
NONE = 0
OPERATING = 1
OS = 2
SERVICE = 3
@@ -83,6 +84,7 @@ class ActionType(Enum):
NODE = 0
ACL = 1
ANY = 2
class ObservationType(Enum):
@@ -100,3 +102,29 @@ class FileSystemState(Enum):
DESTROYED = 3
REPAIRING = 4
RESTORING = 5
class NodeHardwareAction(Enum):
"""Node hardware action."""
NONE = 0
ON = 1
OFF = 2
RESET = 3
class NodeSoftwareAction(Enum):
"""Node software action."""
NONE = 0
PATCHING = 1
class LinkStatus(Enum):
"""Link traffic status."""
NONE = 0
LOW = 1
MEDIUM = 2
HIGH = 3
OVERLOAD = 4

View File

@@ -0,0 +1,403 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
import numpy as np
from gym import spaces
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# This dependency is only needed for type hints,
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite"):
_LOGGER.info(f"Initialising {self} observation component")
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
return NotImplemented
@abstractmethod
def update(self):
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes and links as rows and hardware/software status as cols.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
The number of columns is 4 plus one per service. They are:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
"""
_FIXED_PARAMETERS: int = 4
_MAX_VAL: int = 1_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
num_items = self.env.num_links + self.env.num_nodes
num_columns = self.env.num_services + self._FIXED_PARAMETERS
observation_shape = (num_items, num_columns)
# 2. Create Observation space
self.space = spaces.Box(
low=0,
high=self._MAX_VAL,
shape=observation_shape,
dtype=self._DATA_TYPE,
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeLinkTable`
"""
item_index = 0
nodes = self.env.nodes
links = self.env.links
# Do nodes first
for _, node in nodes.items():
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][
service_index
] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.env.services_list:
self.current_observation[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for _, link in links.items():
self.current_observation[item_index][0] = int(link.get_id())
self.current_observation[item_index][1] = 0
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][
protocol_index + 4
] = protocol.get_load()
protocol_index += 1
item_index += 1
class NodeStatuses(AbstractObservationComponent):
"""Flat list of nodes' hardware, OS, file system, and service states.
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
Each node has 3 elements plus 1 per service. It will have the following structure:
.. code-block::
[
node1 hardware state,
node1 OS state,
node1 file system state,
node1 service1 state,
node1 service2 state,
node1 serviceN state (one for each service),
node2 hardware state,
node2 OS state,
node2 file system state,
node2 service1 state,
node2 service2 state,
node2 serviceN state (one for each service),
...
]
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
node_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
services_shape = [len(SoftwareState) + 1] * self.env.num_services
node_shape = node_shape + services_shape
shape = node_shape * self.env.num_nodes
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeStatuses`
"""
obs = []
for _, node in self.env.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
service_states = [0] * self.env.num_services
if isinstance(node, ActiveNode):
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend(
[hardware_state, software_state, file_system_state, *service_states]
)
self.current_observation[:] = obs
class LinkTrafficLevels(AbstractObservationComponent):
"""Flat list of traffic levels encoded into banded categories.
For each link, total traffic or traffic per service is encoded into a categorical value.
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
.. note::
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value ,
defaults to 5
:type quantisation_levels: int, optional
"""
_DATA_TYPE: type = np.int64
def __init__(
self,
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
if quantisation_levels < 3:
_msg = (
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
f"Resetting to default value (5)"
)
_LOGGER.warning(_msg)
quantisation_levels = 5
super().__init__(env)
self._combine_service_traffic: bool = combine_service_traffic
self._quantisation_levels: int = quantisation_levels
self._entries_per_link: int = 1
if not self._combine_service_traffic:
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
)
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.LinkTrafficLevels`
"""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [protocol.get_load() for protocol in link.protocol_list]
for load in loads:
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (
1 / (self._quantisation_levels - 2)
) + 1
obs.append(int(traffic_level))
self.current_observation[:] = obs
class ObservationsHandler:
"""Component-based observation space handler.
This allows users to configure observation spaces by mixing and matching components.
Each component can also define further parameters to make them more flexible.
"""
_REGISTRY: Final[Dict[str, type]] = {
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
}
def __init__(self):
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
def update_obs(self):
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
obs.update()
current_obs.append(obs.current_observation)
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
else:
self.current_observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
:param obs_component: The component to add.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
"""Remove a component from this handler.
:param obs_component: Which component to remove. It must exist within this object's
``registered_obs_components`` attribute.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# If there is only one component, don't use a tuple space, just pass through that component's space.
if len(component_spaces) == 1:
self.space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
The expected format for the config dictionary is:
..code-block::python
config = {
components: [
{
"name": "<COMPONENT1_NAME>"
},
{
"name": "<COMPONENT2_NAME>"
"options": {"opt1": val1, "opt2": val2}
},
{
...
},
]
}
:return: Observation handler
:rtype: primaite.environment.observations.ObservationsHandler
"""
# Instantiate the handler
handler = cls()
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
comp_class = cls._REGISTRY[comp_type]
# Create the component with options from the YAML
options = component_cfg.get("options") or {}
component = comp_class(env, **options)
handler.register(component)
handler.update_obs()
return handler

View File

@@ -15,6 +15,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -23,11 +24,11 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState,
)
from primaite.common.service import Service
from primaite.environment.observations import ObservationsHandler
from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
@@ -42,19 +43,17 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
class Primaite(Env):
"""PRIMmary AI Training Evironment (Primaite) class."""
# Observation / Action Space contants
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
ACTION_SPACE_NODE_ACTION_VALUES = 4
ACTION_SPACE_ACL_ACTION_VALUES = 3
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
# Action Space contants
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(self, _config_values, _transaction_list):
"""
@@ -149,8 +148,14 @@ class Primaite(Env):
# The action type
self.action_type = 0
# Observation type, by default box.
self.observation_type = ObservationType.BOX
# stores the observation config from the yaml, default is NODE_LINK_TABLE
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
if self.config_values.observation_config is not None:
self.obs_config = self.config_values.observation_config
# Observation Handler manages the user-configurable observation space.
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
try:
@@ -202,15 +207,9 @@ class Primaite(Env):
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_space = spaces.MultiDiscrete(
[
self.num_nodes,
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
self.ACTION_SPACE_NODE_ACTION_VALUES,
self.num_services,
]
)
else:
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
@@ -219,17 +218,14 @@ class Primaite(Env):
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_space = spaces.MultiDiscrete(
[
self.ACTION_SPACE_ACL_ACTION_VALUES,
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
self.num_nodes + 1,
self.num_nodes + 1,
self.num_services + 1,
self.num_ports + 1,
]
)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info("Invalid action type selected")
# Set up a csv to store the results of the training
try:
now = datetime.now() # current date and time
@@ -368,14 +364,14 @@ class Primaite(Env):
# 5. Calculate reward signal (for RL)
reward = calculate_reward_function(
self.nodes_post_pol,
self.nodes_post_blue,
self.nodes_post_red,
self.nodes_reference,
self.green_iers,
self.red_iers,
self.step_count,
self.config_values,
)
# print(f" Step {self.step_count} Reward: {str(reward)}")
print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -432,8 +428,18 @@ class Primaite(Env):
# At the moment, actions are only affecting nodes
if self.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
else:
elif self.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
"""
@@ -442,10 +448,11 @@ class Primaite(Env):
Args:
_action: The action space from the agent
"""
node_id = _action[0]
node_property = _action[1]
property_action = _action[2]
service_index = _action[3]
readable_action = self.action_dict[_action]
node_id = readable_action[0]
node_property = readable_action[1]
property_action = readable_action[2]
service_index = readable_action[3]
# Check that the action is requesting a valid node
try:
@@ -531,12 +538,15 @@ class Primaite(Env):
Args:
_action: The action space from the agent
"""
action_decision = _action[0]
action_permission = _action[1]
action_source_ip = _action[2]
action_destination_ip = _action[3]
action_protocol = _action[4]
action_port = _action[5]
# Convert discrete value back to multidiscrete
readable_action = self.action_dict[_action]
action_decision = readable_action[0]
action_permission = readable_action[1]
action_source_ip = readable_action[2]
action_destination_ip = readable_action[3]
action_protocol = readable_action[4]
action_port = readable_action[5]
if action_decision == 0:
# It's decided to do nothing
@@ -641,252 +651,20 @@ class Primaite(Env):
else:
pass
def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the BOX option chosen.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
Columns are as follows:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
:return: Box gym observation
:rtype: gym.spaces.Box
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space type BOX selected")
# 1. Determine observation shape from laydown
num_items = self.num_links + self.num_nodes
num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
observation_shape = (num_items, num_observation_parameters)
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.Box(
low=0,
high=self.OBSERVATION_SPACE_HIGH_VALUE,
shape=observation_shape,
dtype=np.int64,
)
initial_observation = np.zeros(observation_shape, dtype=np.int64)
return observation_space, initial_observation
def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the MULTIDISCRETE option chosen.
This will create the observation space with node observations followed by link observations.
Each node has 3 elements in the observation space plus 1 per service, more specifically:
* hardware state
* operating system state
* file system state
* service states (one per service)
Each link has one element in the observation space, corresponding to the traffic load,
it can take the following values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(37,)``
:return: MultiDiscrete gym observation
:rtype: gym.spaces.MultiDiscrete
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space MULTIDISCRETE selected")
# 1. Determine observation shape from laydown
node_obs_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
node_services = [len(SoftwareState) + 1] * self.num_services
node_obs_shape = node_obs_shape + node_services
# the magic number 5 refers to 5 states of quantisation of traffic amount.
# (zero, low, medium, high, fully utilised/overwhelmed)
link_obs_shape = [5] * self.num_links
observation_shape = node_obs_shape * self.num_nodes + link_obs_shape
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.MultiDiscrete(observation_shape)
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
return observation_space, initial_observation
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Build the observation space based on network laydown and provide initial obs.
"""Create the environment's observation handler.
This method uses the object's `num_links`, `num_nodes`, `num_services`,
`OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type`
attributes to figure out the correct shape and format for the observation space.
:raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType`
:return: Gym observation space
:rtype: gym.spaces.Space
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
if self.observation_type == ObservationType.BOX:
observation_space, initial_observation = self._init_box_observations()
return observation_space, initial_observation
elif self.observation_type == ObservationType.MULTIDISCRETE:
(
observation_space,
initial_observation,
) = self._init_multidiscrete_observations()
return observation_space, initial_observation
else:
errmsg = (
f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
f", got {self.observation_type} instead"
)
_LOGGER.error(errmsg)
raise ValueError(errmsg)
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
def _update_env_obs_box(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_box_observations`
This function can only be called if the observation space setting is set to BOX.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.BOX
item_index = 0
# Do nodes first
for node_key, node in self.nodes.items():
self.env_obs[item_index][0] = int(node.node_id)
self.env_obs[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.env_obs[item_index][2] = node.software_state.value
self.env_obs[item_index][3] = node.file_system_state_observed.value
else:
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.services_list:
if node.has_service(service):
self.env_obs[item_index][
service_index
] = node.get_service_state(service).value
else:
self.env_obs[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.services_list:
self.env_obs[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for link_key, link in self.links.items():
self.env_obs[item_index][0] = int(link.get_id())
self.env_obs[item_index][1] = 0
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
def _update_env_obs_multidiscrete(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_multidiscrete_observations`
This function can only be called if the observation space setting is set to MULTIDISCRETE.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.MULTIDISCRETE
obs = []
# 1. Set nodes
# Each node has the following variables in the observation space:
# - Hardware state
# - Software state
# - File System state
# - Service 1 state
# - Service 2 state
# - ...
# - Service N state
for node_key, node in self.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
services_states = [0] * self.num_services
if isinstance(
node, ActiveNode
): # ServiceNode is a subclass of ActiveNode so no need to check that also
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.services_list):
if node.has_service(service):
services_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
software_state,
file_system_state,
*services_states,
]
)
# 2. Set links
# Each link has just one variable in the observation space, it represents the traffic amount
# In order for the space to be fully MultiDiscrete, the amount of
# traffic on each link is quantised into a few levels:
# 0: no traffic (0% of bandwidth)
# 1: low traffic (0-33% of bandwidth)
# 2: medium traffic (33-66% of bandwidth)
# 3: high traffic (66-100% of bandwidth)
# 4: max traffic/overloaded (100% of bandwidth)
for link_key, link in self.links.items():
bandwidth = link.bandwidth
load = link.get_current_load()
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = 4
else:
traffic_level = (load / bandwidth) // (1 / 3) + 1
obs.append(int(traffic_level))
self.env_obs = np.asarray(obs)
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
"""Updates the observation space based on the node and link status."""
if self.observation_type == ObservationType.BOX:
self._update_env_obs_box()
elif self.observation_type == ObservationType.MULTIDISCRETE:
self._update_env_obs_multidiscrete()
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
def load_config(self):
"""Loads config data in order to build the environment configuration."""
@@ -921,9 +699,6 @@ class Primaite(Env):
elif item["itemType"] == "ACTIONS":
# Get the action information
self.get_action_info(item)
elif item["itemType"] == "OBSERVATIONS":
# Get the observation information
self.get_observation_info(item)
elif item["itemType"] == "STEPS":
# Get the steps information
self.get_steps_info(item)
@@ -1256,13 +1031,17 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def get_observation_info(self, observation_info):
"""Extracts observation_info.
def save_obs_config(self, obs_config: dict):
"""Cache the config for the observation space.
:param observation_info: Config item that defines which type of observation space to use
:type observation_info: str
This is necessary as the observation space can't be built while reading the config,
it must be done after all the nodes, links, and services have been initialised.
:param obs_config: Parsed config relating to the observation space. The format is described in
:py:meth:`primaite.environment.observations.ObservationsHandler.from_config`
:type obs_config: dict
"""
self.observation_type = ObservationType[observation_info["type"]]
self.obs_config = obs_config
def get_steps_info(self, steps_info):
"""
@@ -1347,3 +1126,91 @@ class Primaite(Env):
else:
# Bad formatting
pass
def create_node_action_dict(self):
"""
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
example return:
{0: [1, 0, 0, 0],
1: [1, 1, 1, 0],
2: [1, 1, 2, 0],
3: [1, 1, 3, 0],
4: [1, 2, 1, 0],
5: [1, 3, 1, 0],
...
}
"""
# reserve 0 action to be a nothing action
actions = {0: [1, 0, 0, 0]}
action_key = 1
for node in range(1, self.num_nodes + 1):
# 4 node properties (NONE, OPERATING, OS, SERVICE)
for node_property in range(4):
# Node Actions either:
# (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state
# Use MAX to ensure we get them all
for node_action in range(4):
for service_state in range(self.num_services):
action = [node, node_property, node_action, service_state]
# check to see if it's a nothing action (has no effect)
if is_valid_node_action(action):
actions[action_key] = action
action_key += 1
return actions
def create_acl_action_dict(self):
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0]}
action_key = 1
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
for action_decision in range(3):
# 2 possible action permissions 0 = DENY, 1 = CREATE
for action_permission in range(2):
# Number of nodes + 1 (for any)
for source_ip in range(self.num_nodes + 1):
for dest_ip in range(self.num_nodes + 1):
for protocol in range(self.num_services + 1):
for port in range(self.num_ports + 1):
action = [
action_decision,
action_permission,
source_ip,
dest_ip,
protocol,
port,
]
# Check to see if its an action we want to include as possible i.e. not a nothing action
if is_valid_acl_action_extra(action):
actions[action_key] = action
action_key += 1
return actions
def create_node_and_acl_action_dict(self):
"""
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
The dictionary contains actions of both Node and ACL action types.
"""
node_action_dict = self.create_node_action_dict()
acl_action_dict = self.create_acl_action_dict()
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict

View File

@@ -24,6 +24,7 @@ from primaite.transactions.transactions_to_file import write_transaction_to_file
def run_generic():
"""Run against a generic agent."""
for episode in range(0, config_values.num_episodes):
env.reset()
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
@@ -41,7 +42,6 @@ def run_generic():
time.sleep(config_values.time_delay / 1000)
# Reset the environment at the end of the episode
env.reset()
env.close()
@@ -162,6 +162,10 @@ def load_config_values():
try:
# Generic
config_values.agent_identifier = config_data["agentIdentifier"]
if "observationSpace" in config_data:
config_values.observation_config = config_data["observationSpace"]
else:
config_values.observation_config = None
config_values.num_episodes = int(config_data["numEpisodes"])
config_values.time_delay = int(config_data["timeDelay"])
config_values.config_filename_use_case = (
@@ -376,7 +380,7 @@ logging.info("Saving transaction logs...")
write_transaction_to_file(transaction_list)
config_file_main.close
config_file_main.close()
print("Finished")
logging.info("Finished")

View File

@@ -1,68 +0,0 @@
- itemType: ACTIONS
type: NODE
- itemType: OBSERVATIONS
type: MULTIDISCRETE
- itemType: STEPS
steps: 5
- itemType: PORTS
portsList:
- port: '80'
- itemType: SERVICES
serviceList:
- name: TCP
########################################
# Nodes
- itemType: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- itemType: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- itemType: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- itemType: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- itemType: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'

View File

@@ -1,15 +1,15 @@
- itemType: ACTIONS
type: NODE
- itemType: OBSERVATIONS
type: BOX
- itemType: STEPS
steps: 5
- itemType: PORTS
portsList:
- port: '80'
- port: '53'
- itemType: SERVICES
serviceList:
- name: TCP
- name: UDP
########################################
# Nodes
@@ -21,12 +21,15 @@
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: GOOD
software_state: COMPROMISED
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '2'
name: SERVER
@@ -41,6 +44,9 @@
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: OVERWHELMED
- itemType: NODE
node_id: '3'
name: SWITCH1
@@ -66,3 +72,33 @@
bandwidth: 1000
source: '3'
destination: '2'
#########################################
# IERS
- itemType: GREEN_IER
id: '5'
startStep: 0
endStep: 5
load: 999
protocol: TCP
port: '80'
source: '1'
destination: '2'
missionCriticality: 5
#########################################
# ACL Rules
- itemType: ACL_RULE
id: '6'
permission: ALLOW
source: 192.168.1.1
destination: 192.168.1.2
protocol: TCP
port: 80
- itemType: ACL_RULE
id: '7'
permission: ALLOW
source: 192.168.1.2
destination: 192.168.1.1
protocol: TCP
port: 80

View File

@@ -0,0 +1,96 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: NONE
# Number of episodes to run per session
observationSpace:
components:
- name: LINK_TRAFFIC_LEVELS
options:
combine_service_traffic: false
quantisation_levels: 8
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: one_node_states_on_off_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1_000_000_000
# Reward values
# Generic
allOk: 0
# Node Hardware State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node Software or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,93 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: NONE
# Number of episodes to run per session
observationSpace:
components:
- name: NODE_LINK_TABLE
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: one_node_states_on_off_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1_000_000_000
# Reward values
# Generic
allOk: 0
# Node Hardware State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node Software or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,93 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: NONE
# Number of episodes to run per session
observationSpace:
components:
- name: NODE_STATUSES
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: one_node_states_on_off_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1_000_000_000
# Reward values
# Generic
allOk: 0
# Node Hardware State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node Software or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,89 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: NONE
# Number of episodes to run per session
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: one_node_states_on_off_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1_000_000_000
# Reward values
# Generic
allOk: 0
# Node Hardware State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node Software or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,89 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: GENERIC
# Number of episodes to run per session
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: single_action_space_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1000000000
# Reward values
# Generic
allOk: 0
# Node Operating State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node O/S or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -0,0 +1,55 @@
- itemType: ACTIONS
type: ANY
- itemType: STEPS
steps: 15
- itemType: PORTS
portsList:
- port: '21'
- itemType: SERVICES
serviceList:
- name: ftp
- itemType: NODE
node_id: '1'
name: node
node_class: SERVICE
node_type: COMPUTER
priority: P1
hardware_state: 'ON'
ip_address: 192.168.0.14
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
state: COMPROMISED
- itemType: NODE
node_id: '2'
name: server_1
node_class: SERVICE
node_type: SERVER
priority: P1
hardware_state: 'ON'
ip_address: 192.168.0.1
software_state: GOOD
file_system_state: GOOD
services:
- name: ftp
port: '21'
state: COMPROMISED
- itemType: POSITION
positions:
- node: '1'
x_pos: 309
y_pos: 78
- node: '2'
x_pos: 200
y_pos: 78
- itemType: RED_IER
id: '3'
startStep: 2
endStep: 15
load: 1000
protocol: ftp
port: CORRUPT
source: '1'
destination: '2'

View File

@@ -0,0 +1,89 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: GENERIC
# Number of episodes to run per session
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: single_action_space_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1000000000
# Reward values
# Generic
allOk: 0
# Node Operating State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node O/S or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -19,6 +19,10 @@ def _get_primaite_env_from_config(
def load_config_values():
config_values.agent_identifier = config_data["agentIdentifier"]
if "observationSpace" in config_data:
config_values.observation_config = config_data["observationSpace"]
else:
config_values.observation_config = None
config_values.num_episodes = int(config_data["numEpisodes"])
config_values.time_delay = int(config_data["timeDelay"])
config_values.config_filename_use_case = lay_down_config_path
@@ -164,12 +168,13 @@ def _get_primaite_env_from_config(
# Load in config data
load_config_values()
env = Primaite(config_values, [])
# Get the number of steps (which is stored in the child config file)
config_values.num_steps = env.episode_steps
if env.config_values.agent_identifier == "GENERIC":
run_generic(env, config_values)
return env
return env, config_values
def run_generic(env, config_values):
@@ -181,7 +186,8 @@ def run_generic(env, config_values):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = env.action_space.sample()
# action = env.action_space.sample()
action = 0
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)

View File

@@ -1,34 +1,220 @@
"""Test env creation and behaviour with different observation spaces."""
import numpy as np
import pytest
from primaite.environment.observations import (
NodeLinkTable,
NodeStatuses,
ObservationsHandler,
)
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def test_creating_env_with_box_obs():
"""Try creating env with box observation space."""
env = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml",
@pytest.fixture
def env(request):
"""Build Primaite environment for integration tests of observation space."""
marker = request.node.get_closest_marker("env_config_paths")
main_config_path = marker.args[0]["main_config_path"]
lay_down_config_path = marker.args[0]["lay_down_config_path"]
env, _ = _get_primaite_env_from_config(
main_config_path=main_config_path,
lay_down_config_path=lay_down_config_path,
)
yield env
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
def test_default_obs_space(env: Primaite):
"""Create environment with no obs space defined in config and check that the default obs space was created."""
env.update_environent_obs()
# we have three nodes and two links, with one service
# therefore the box observation space will have:
# * 5 columns (four fixed and one for the service)
# * 5 rows (3 nodes + 2 links)
assert env.env_obs.shape == (5, 5)
components = env.obs_handler.registered_obs_components
assert len(components) == 1
assert isinstance(components[0], NodeLinkTable)
def test_creating_env_with_multidiscrete_obs():
"""Try creating env with MultiDiscrete observation space."""
env = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "multidiscrete_obs_space_laydown_config.yaml",
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_without_obs.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
env.update_environent_obs()
)
def test_registering_components(env: Primaite):
"""Test regitering and deregistering a component."""
handler = ObservationsHandler()
component = NodeStatuses(env)
handler.register(component)
assert component in handler.registered_obs_components
handler.deregister(component)
assert component not in handler.registered_obs_components
# we have three nodes and two links, with one service
# the nodes have hardware, OS, FS, and service, the links just have bandwidth,
# therefore we need 3*4 + 2 observations
assert env.env_obs.shape == (3 * 4 + 2,)
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_NODE_LINK_TABLE.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
class TestNodeLinkTable:
"""Test the NodeLinkTable observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
"""Try creating env with box observation space."""
env.update_environent_obs()
# we have three nodes and two links, with two service
# therefore the box observation space will have:
# * 5 rows (3 nodes + 2 links)
# * 6 columns (four fixed and two for the services)
assert env.env_obs.shape == (5, 6)
def test_value(self, env: Primaite):
"""Test that the observation is generated correctly.
The laydown has:
* 3 nodes (2 service nodes and 1 active node)
* 2 services
* 2 links
Both nodes have both services, and all states are GOOD, therefore the expected observation value is:
* Node 1:
* 1 (id)
* 1 (good hardware state)
* 3 (compromised OS state)
* 1 (good file system state)
* 1 (good TCP state)
* 1 (good UDP state)
* Node 2:
* 2 (id)
* 1 (good hardware state)
* 1 (good OS state)
* 1 (good file system state)
* 1 (good TCP state)
* 4 (overwhelmed UDP state)
* Node 3 (active node):
* 3 (id)
* 1 (good hardware state)
* 1 (good OS state)
* 1 (good file system state)
* 0 (doesn't have service1)
* 0 (doesn't have service2)
* Link 1:
* 4 (id)
* 0 (n/a hardware state)
* 0 (n/a OS state)
* 0 (n/a file system state)
* 999 (999 traffic for service1)
* 0 (no traffic for service2)
* Link 2:
* 5 (id)
* 0 (good hardware state)
* 0 (good OS state)
* 0 (good file system state)
* 999 (999 traffic service1)
* 0 (no traffic for service2)
"""
# act = np.asarray([0,])
obs, reward, done, info = env.step(0) # apply the 'do nothing' action
assert np.array_equal(
obs,
[
[1, 1, 3, 1, 1, 1],
[2, 1, 1, 1, 1, 4],
[3, 1, 1, 1, 0, 0],
[4, 0, 0, 0, 999, 0],
[5, 0, 0, 0, 999, 0],
],
)
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_NODE_STATUSES.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
class TestNodeStatuses:
"""Test the NodeStatuses observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
"""Try creating env with NodeStatuses as the only component."""
assert env.env_obs.shape == (15,)
def test_values(self, env: Primaite):
"""Test that the hardware and software states are encoded correctly.
The laydown has:
* one node with a compromised operating system state
* one node with two services, and the second service is overwhelmed.
* all other states are good or null
Therefore, the expected state is:
* node 1:
* hardware = good (1)
* OS = compromised (3)
* file system = good (1)
* service 1 = good (1)
* service 2 = good (1)
* node 2:
* hardware = good (1)
* OS = good (1)
* file system = good (1)
* service 1 = good (1)
* service 2 = overwhelmed (4)
* node 3 (switch):
* hardware = good (1)
* OS = good (1)
* file system = good (1)
* service 1 = n/a (0)
* service 2 = n/a (0)
"""
obs, _, _, _ = env.step(0) # apply the 'do nothing' action
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT
/ "obs_tests/main_config_LINK_TRAFFIC_LEVELS.yaml",
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
)
)
class TestLinkTrafficLevels:
"""Test the LinkTrafficLevels observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
"""Try creating env with MultiDiscrete observation space."""
env.update_environent_obs()
# we have two links and two services, so the shape should be 2 * 2
assert env.env_obs.shape == (2 * 2,)
def test_values(self, env: Primaite):
"""Test that traffic values are encoded correctly.
The laydown has:
* two services
* three nodes
* two links
* an IER trying to send 999 bits of data over both links the whole time (via the first service)
* link bandwidth of 1000, therefore the utilisation is 99.9%
"""
obs, reward, done, info = env.step(0)
obs, reward, done, info = env.step(0)
# the observation space has combine_service_traffic set to False, so the space has this format:
# [link1_service1, link1_service2, link2_service1, link2_service2]
# we send 999 bits of data via link1 and link2 on service 1.
# therefore the first and third elements should be 6 and all others 0
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
assert np.array_equal(obs, [6, 0, 6, 0])

View File

@@ -8,7 +8,7 @@ def test_rewards_are_being_penalised_at_each_step_function():
When the initial state is OFF compared to reference state which is ON.
"""
env = _get_primaite_env_from_config(
env, config_values = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "one_node_states_on_off_lay_down_config.yaml",

View File

@@ -0,0 +1,100 @@
import time
from primaite.common.enums import HardwareState
from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config
def run_generic_set_actions(env, config_values):
"""Run against a generic agent with specified blue agent actions."""
# Reset the environment at the start of the episode
# env.reset()
for episode in range(0, config_values.num_episodes):
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
# action = env.blue_agent_action(obs)
action = 0
print("Episode:", episode, "\nStep:", step)
if step == 5:
# [1, 1, 2, 1, 1, 1]
# Creates an ACL rule
# Allows traffic from server_1 to node_1 on port FTP
action = 7
elif step == 7:
# [1, 1, 2, 0] Node Action
# Sets Node 1 Hardware State to OFF
# Does not resolve any service
action = 16
# Run the simulation step on the live environment
obs, reward, done, info = env.step(action)
# Break if done is True
if done:
break
# Introduce a delay between steps
time.sleep(config_values.time_delay / 1000)
# Reset the environment at the end of the episode
# env.reset()
# env.close()
def test_single_action_space_is_valid():
"""Test to ensure the blue agent is using the ACL action space and is carrying out both kinds of operations."""
env, config_values = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT / "single_action_space_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
run_generic_set_actions(env, config_values)
# Retrieve the action space dictionary values from environment
env_action_space_dict = env.action_dict.values()
# Flags to check the conditions of the action space
contains_acl_actions = False
contains_node_actions = False
both_action_spaces = False
# Loop through each element of the list (which is every value from the dictionary)
for dict_item in env_action_space_dict:
# Node action detected
if len(dict_item) == 4:
contains_node_actions = True
# Link action detected
elif len(dict_item) == 6:
contains_acl_actions = True
# If both are there then the ANY action type is working
if contains_node_actions and contains_acl_actions:
both_action_spaces = True
# Check condition should be True
assert both_action_spaces
def test_agent_is_executing_actions_from_both_spaces():
"""Test to ensure the blue agent is carrying out both kinds of operations (NODE & ACL)."""
env, config_values = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT
/ "single_action_space_fixed_blue_actions_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "single_action_space_lay_down_config.yaml",
)
# Run environment with specified fixed blue agent actions only
run_generic_set_actions(env, config_values)
# Retrieve hardware state of computer_1 node in laydown config
# Agent turned this off in Step 5
computer_node_hardware_state = env.nodes["1"].hardware_state
# Retrieve the Access Control List object stored by the environment at the end of the episode
access_control_list = env.acl
# Use the Access Control List object acl object attribute to get dictionary
# Use dictionary.values() to get total list of all items in the dictionary
acl_rules_list = access_control_list.acl.values()
# Length of this list tells you how many items are in the dictionary
# This number is the frequency of Access Control Rules in the environment
# In the scenario, we specified that the agent should create only 1 acl rule
num_of_rules = len(acl_rules_list)
# Therefore these statements below MUST be true
assert computer_node_hardware_state == HardwareState.OFF
assert num_of_rules == 1