Merge remote-tracking branch 'origin/dev' into feature/898-Fix-the-functionality-of-resetting-a-node

This commit is contained in:
Chris McCarthy
2023-06-12 14:20:16 +01:00
23 changed files with 1758 additions and 420 deletions

View File

View File

@@ -0,0 +1,127 @@
from primaite.common.enums import NodeHardwareAction, NodePOLType, NodeSoftwareAction
def transform_action_node_readable(action):
"""
Convert a node action from enumerated format to readable format.
example:
[1, 3, 1, 0] -> [1, 'SERVICE', 'PATCHING', 0]
"""
action_node_property = NodePOLType(action[1]).name
if action_node_property == "OPERATING":
property_action = NodeHardwareAction(action[2]).name
elif (action_node_property == "OS" or action_node_property == "SERVICE") and action[
2
] <= 1:
property_action = NodeSoftwareAction(action[2]).name
else:
property_action = "NONE"
new_action = [action[0], action_node_property, property_action, action[3]]
return new_action
def transform_action_acl_readable(action):
"""
Transform an ACL action to a more readable format.
example:
[0, 1, 2, 5, 0, 1] -> ['NONE', 'ALLOW', 2, 5, 'ANY', 1]
"""
action_decisions = {0: "NONE", 1: "CREATE", 2: "DELETE"}
action_permissions = {0: "DENY", 1: "ALLOW"}
action_decision = action_decisions[action[0]]
action_permission = action_permissions[action[1]]
# For IPs, Ports and Protocols, 0 means any, otherwise its just an index
new_action = [action_decision, action_permission] + list(action[2:6])
for n, val in enumerate(list(action[2:6])):
if val == 0:
new_action[n + 2] = "ANY"
return new_action
def is_valid_node_action(action):
"""Is the node action an actual valid action.
Only uses information about the action to determine if the action has an effect
Does NOT consider:
- Node ID not valid to perform an operation - e.g. selected node has no service so cannot patch
- Node already being in that state (turning an ON node ON)
"""
action_r = transform_action_node_readable(action)
node_property = action_r[1]
node_action = action_r[2]
# print("node property", node_property, "\nnode action", node_action)
if node_property == "NONE":
return False
if node_action == "NONE":
return False
if node_property == "OPERATING" and node_action == "PATCHING":
# Operating State cannot PATCH
return False
if node_property != "OPERATING" and node_action not in ["NONE", "PATCHING"]:
# Software States can only do Nothing or Patch
return False
return True
def is_valid_acl_action(action):
"""
Is the ACL action an actual valid action.
Only uses information about the action to determine if the action has an effect.
Does NOT consider:
- Trying to create identical rules
- Trying to create a rule which is a subset of another rule (caused by "ANY")
"""
action_r = transform_action_acl_readable(action)
action_decision = action_r[0]
action_permission = action_r[1]
action_source_id = action_r[2]
action_destination_id = action_r[3]
if action_decision == "NONE":
return False
if (
action_source_id == action_destination_id
and action_source_id != "ANY"
and action_destination_id != "ANY"
):
# ACL rule towards itself
return False
if action_permission == "DENY":
# DENY is unnecessary, we can create and delete allow rules instead
# No allow rule = blocked/DENY by feault. ALLOW overrides existing DENY.
return False
return True
def is_valid_acl_action_extra(action):
"""Harsher version of valid acl actions, does not allow action."""
if is_valid_acl_action(action) is False:
return False
action_r = transform_action_acl_readable(action)
action_protocol = action_r[4]
action_port = action_r[5]
# Don't allow protocols or ports to be ANY
# in the future we might want to do the opposite, and only have ANY option for ports and service
if action_protocol == "ANY":
return False
if action_port == "ANY":
return False
return True

View File

@@ -9,6 +9,7 @@ class ConfigValuesMain(object):
"""Init."""
# Generic
self.agent_identifier = "" # the agent in use
self.observation_config = None # observation space config
self.num_episodes = 0 # number of episodes to train over
self.num_steps = 0 # number of steps in an episode
self.time_delay = 0 # delay between steps (ms) - applies to generic agents only

View File

@@ -51,6 +51,7 @@ class SoftwareState(Enum):
class NodePOLType(Enum):
"""Node Pattern of Life type enumeration."""
NONE = 0
OPERATING = 1
OS = 2
SERVICE = 3
@@ -83,6 +84,7 @@ class ActionType(Enum):
NODE = 0
ACL = 1
ANY = 2
class ObservationType(Enum):
@@ -100,3 +102,29 @@ class FileSystemState(Enum):
DESTROYED = 3
REPAIRING = 4
RESTORING = 5
class NodeHardwareAction(Enum):
"""Node hardware action."""
NONE = 0
ON = 1
OFF = 2
RESET = 3
class NodeSoftwareAction(Enum):
"""Node software action."""
NONE = 0
PATCHING = 1
class LinkStatus(Enum):
"""Link traffic status."""
NONE = 0
LOW = 1
MEDIUM = 2
HIGH = 3
OVERLOAD = 4

View File

@@ -0,0 +1,403 @@
"""Module for handling configurable observation spaces in PrimAITE."""
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
import numpy as np
from gym import spaces
from primaite.common.enums import FileSystemState, HardwareState, SoftwareState
from primaite.nodes.active_node import ActiveNode
from primaite.nodes.service_node import ServiceNode
# This dependency is only needed for type hints,
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
# Therefore, this avoids circular dependency problem.
if TYPE_CHECKING:
from primaite.environment.primaite_env import Primaite
_LOGGER = logging.getLogger(__name__)
class AbstractObservationComponent(ABC):
"""Represents a part of the PrimAITE observation space."""
@abstractmethod
def __init__(self, env: "Primaite"):
_LOGGER.info(f"Initialising {self} observation component")
self.env: "Primaite" = env
self.space: spaces.Space
self.current_observation: np.ndarray # type might be too restrictive?
return NotImplemented
@abstractmethod
def update(self):
"""Update the observation based on the current state of the environment."""
self.current_observation = NotImplemented
class NodeLinkTable(AbstractObservationComponent):
"""Table with nodes and links as rows and hardware/software status as cols.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
The number of columns is 4 plus one per service. They are:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
"""
_FIXED_PARAMETERS: int = 4
_MAX_VAL: int = 1_000_000
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
num_items = self.env.num_links + self.env.num_nodes
num_columns = self.env.num_services + self._FIXED_PARAMETERS
observation_shape = (num_items, num_columns)
# 2. Create Observation space
self.space = spaces.Box(
low=0,
high=self._MAX_VAL,
shape=observation_shape,
dtype=self._DATA_TYPE,
)
# 3. Initialise Observation with zeroes
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeLinkTable`
"""
item_index = 0
nodes = self.env.nodes
links = self.env.links
# Do nodes first
for _, node in nodes.items():
self.current_observation[item_index][0] = int(node.node_id)
self.current_observation[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.current_observation[item_index][2] = node.software_state.value
self.current_observation[item_index][
3
] = node.file_system_state_observed.value
else:
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.env.services_list:
if node.has_service(service):
self.current_observation[item_index][
service_index
] = node.get_service_state(service).value
else:
self.current_observation[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.env.services_list:
self.current_observation[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for _, link in links.items():
self.current_observation[item_index][0] = int(link.get_id())
self.current_observation[item_index][1] = 0
self.current_observation[item_index][2] = 0
self.current_observation[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.current_observation[item_index][
protocol_index + 4
] = protocol.get_load()
protocol_index += 1
item_index += 1
class NodeStatuses(AbstractObservationComponent):
"""Flat list of nodes' hardware, OS, file system, and service states.
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
integers.
Each node has 3 elements plus 1 per service. It will have the following structure:
.. code-block::
[
node1 hardware state,
node1 OS state,
node1 file system state,
node1 service1 state,
node1 service2 state,
node1 serviceN state (one for each service),
node2 hardware state,
node2 OS state,
node2 file system state,
node2 service1 state,
node2 service2 state,
node2 serviceN state (one for each service),
...
]
:param env: The environment that forms the basis of the observations
:type env: Primaite
"""
_DATA_TYPE: type = np.int64
def __init__(self, env: "Primaite"):
super().__init__(env)
# 1. Define the shape of your observation space component
node_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
services_shape = [len(SoftwareState) + 1] * self.env.num_services
node_shape = node_shape + services_shape
shape = node_shape * self.env.num_nodes
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.NodeStatuses`
"""
obs = []
for _, node in self.env.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
service_states = [0] * self.env.num_services
if isinstance(node, ActiveNode):
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.env.services_list):
if node.has_service(service):
service_states[i] = node.get_service_state(service).value
obs.extend(
[hardware_state, software_state, file_system_state, *service_states]
)
self.current_observation[:] = obs
class LinkTrafficLevels(AbstractObservationComponent):
"""Flat list of traffic levels encoded into banded categories.
For each link, total traffic or traffic per service is encoded into a categorical value.
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
.. note::
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
:param env: The environment that forms the basis of the observations
:type env: Primaite
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
defaults to False
:type combine_service_traffic: bool, optional
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value ,
defaults to 5
:type quantisation_levels: int, optional
"""
_DATA_TYPE: type = np.int64
def __init__(
self,
env: "Primaite",
combine_service_traffic: bool = False,
quantisation_levels: int = 5,
):
if quantisation_levels < 3:
_msg = (
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
f"Resetting to default value (5)"
)
_LOGGER.warning(_msg)
quantisation_levels = 5
super().__init__(env)
self._combine_service_traffic: bool = combine_service_traffic
self._quantisation_levels: int = quantisation_levels
self._entries_per_link: int = 1
if not self._combine_service_traffic:
self._entries_per_link = self.env.num_services
# 1. Define the shape of your observation space component
shape = (
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
)
# 2. Create Observation space
self.space = spaces.MultiDiscrete(shape)
# 3. Initialise observation with zeroes
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
def update(self):
"""Update the observation based on current environment state.
The structure of the observation space is described in :class:`.LinkTrafficLevels`
"""
obs = []
for _, link in self.env.links.items():
bandwidth = link.bandwidth
if self._combine_service_traffic:
loads = [link.get_current_load()]
else:
loads = [protocol.get_load() for protocol in link.protocol_list]
for load in loads:
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = self._quantisation_levels - 1
else:
traffic_level = (load / bandwidth) // (
1 / (self._quantisation_levels - 2)
) + 1
obs.append(int(traffic_level))
self.current_observation[:] = obs
class ObservationsHandler:
"""Component-based observation space handler.
This allows users to configure observation spaces by mixing and matching components.
Each component can also define further parameters to make them more flexible.
"""
_REGISTRY: Final[Dict[str, type]] = {
"NODE_LINK_TABLE": NodeLinkTable,
"NODE_STATUSES": NodeStatuses,
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
}
def __init__(self):
self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
def update_obs(self):
"""Fetch fresh information about the environment."""
current_obs = []
for obs in self.registered_obs_components:
obs.update()
current_obs.append(obs.current_observation)
# If there is only one component, don't use a tuple, just pass through that component's obs.
if len(current_obs) == 1:
self.current_observation = current_obs[0]
else:
self.current_observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track.
:param obs_component: The component to add.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.append(obs_component)
self.update_space()
def deregister(self, obs_component: AbstractObservationComponent):
"""Remove a component from this handler.
:param obs_component: Which component to remove. It must exist within this object's
``registered_obs_components`` attribute.
:type obs_component: AbstractObservationComponent
"""
self.registered_obs_components.remove(obs_component)
self.update_space()
def update_space(self):
"""Rebuild the handler's composite observation space from its components."""
component_spaces = []
for obs_comp in self.registered_obs_components:
component_spaces.append(obs_comp.space)
# If there is only one component, don't use a tuple space, just pass through that component's space.
if len(component_spaces) == 1:
self.space = component_spaces[0]
else:
self.space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
@classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict):
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
The expected format for the config dictionary is:
..code-block::python
config = {
components: [
{
"name": "<COMPONENT1_NAME>"
},
{
"name": "<COMPONENT2_NAME>"
"options": {"opt1": val1, "opt2": val2}
},
{
...
},
]
}
:return: Observation handler
:rtype: primaite.environment.observations.ObservationsHandler
"""
# Instantiate the handler
handler = cls()
for component_cfg in obs_space_config["components"]:
# Figure out which class can instantiate the desired component
comp_type = component_cfg["name"]
comp_class = cls._REGISTRY[comp_type]
# Create the component with options from the YAML
options = component_cfg.get("options") or {}
component = comp_class(env, **options)
handler.register(component)
handler.update_obs()
return handler

View File

@@ -15,6 +15,7 @@ from gym import Env, spaces
from matplotlib import pyplot as plt
from primaite.acl.access_control_list import AccessControlList
from primaite.agents.utils import is_valid_acl_action_extra, is_valid_node_action
from primaite.common.custom_typing import NodeUnion
from primaite.common.enums import (
ActionType,
@@ -23,11 +24,11 @@ from primaite.common.enums import (
NodePOLInitiator,
NodePOLType,
NodeType,
ObservationType,
Priority,
SoftwareState,
)
from primaite.common.service import Service
from primaite.environment.observations import ObservationsHandler
from primaite.environment.reward import calculate_reward_function
from primaite.links.link import Link
from primaite.nodes.active_node import ActiveNode
@@ -42,19 +43,17 @@ from primaite.pol.red_agent_pol import apply_red_agent_iers, apply_red_agent_nod
from primaite.transactions.transaction import Transaction
_LOGGER = logging.getLogger(__name__)
_LOGGER.setLevel(logging.INFO)
class Primaite(Env):
"""PRIMmary AI Training Evironment (Primaite) class."""
# Observation / Action Space contants
OBSERVATION_SPACE_FIXED_PARAMETERS = 4
ACTION_SPACE_NODE_PROPERTY_VALUES = 5
ACTION_SPACE_NODE_ACTION_VALUES = 4
ACTION_SPACE_ACL_ACTION_VALUES = 3
ACTION_SPACE_ACL_PERMISSION_VALUES = 2
OBSERVATION_SPACE_HIGH_VALUE = 1000000 # Highest value within an observation space
# Action Space contants
ACTION_SPACE_NODE_PROPERTY_VALUES: int = 5
ACTION_SPACE_NODE_ACTION_VALUES: int = 4
ACTION_SPACE_ACL_ACTION_VALUES: int = 3
ACTION_SPACE_ACL_PERMISSION_VALUES: int = 2
def __init__(self, _config_values, _transaction_list):
"""
@@ -149,8 +148,14 @@ class Primaite(Env):
# The action type
self.action_type = 0
# Observation type, by default box.
self.observation_type = ObservationType.BOX
# stores the observation config from the yaml, default is NODE_LINK_TABLE
self.obs_config: dict = {"components": [{"name": "NODE_LINK_TABLE"}]}
if self.config_values.observation_config is not None:
self.obs_config = self.config_values.observation_config
# Observation Handler manages the user-configurable observation space.
# It will be initialised later.
self.obs_handler: ObservationsHandler
# Open the config file and build the environment laydown
try:
@@ -202,15 +207,9 @@ class Primaite(Env):
# [0, 4] - what property it's acting on (0 = nothing, state, SoftwareState, service state, file system state) # noqa
# [0, 3] - action on property (0 = nothing, On / Scan, Off / Repair, Reset / Patch / Restore) # noqa
# [0, num services] - resolves to service ID (0 = nothing, resolves to service) # noqa
self.action_space = spaces.MultiDiscrete(
[
self.num_nodes,
self.ACTION_SPACE_NODE_PROPERTY_VALUES,
self.ACTION_SPACE_NODE_ACTION_VALUES,
self.num_services,
]
)
else:
self.action_dict = self.create_node_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ACL:
_LOGGER.info("Action space type ACL selected")
# Terms (for ACL action space):
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
@@ -219,17 +218,14 @@ class Primaite(Env):
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
# [0, num services] - Protocol (0 = any, then 1 -> x resolving to protocol)
# [0, num ports] - Port (0 = any, then 1 -> x resolving to port)
self.action_space = spaces.MultiDiscrete(
[
self.ACTION_SPACE_ACL_ACTION_VALUES,
self.ACTION_SPACE_ACL_PERMISSION_VALUES,
self.num_nodes + 1,
self.num_nodes + 1,
self.num_services + 1,
self.num_ports + 1,
]
)
self.action_dict = self.create_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
elif self.action_type == ActionType.ANY:
_LOGGER.info("Action space type ANY selected - Node + ACL")
self.action_dict = self.create_node_and_acl_action_dict()
self.action_space = spaces.Discrete(len(self.action_dict))
else:
_LOGGER.info("Invalid action type selected")
# Set up a csv to store the results of the training
try:
now = datetime.now() # current date and time
@@ -368,14 +364,14 @@ class Primaite(Env):
# 5. Calculate reward signal (for RL)
reward = calculate_reward_function(
self.nodes_post_pol,
self.nodes_post_blue,
self.nodes_post_red,
self.nodes_reference,
self.green_iers,
self.red_iers,
self.step_count,
self.config_values,
)
# print(f" Step {self.step_count} Reward: {str(reward)}")
print(f" Step {self.step_count} Reward: {str(reward)}")
self.total_reward += reward
if self.step_count == self.episode_steps:
self.average_reward = self.total_reward / self.step_count
@@ -432,8 +428,18 @@ class Primaite(Env):
# At the moment, actions are only affecting nodes
if self.action_type == ActionType.NODE:
self.apply_actions_to_nodes(_action)
else:
elif self.action_type == ActionType.ACL:
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 6
): # ACL actions in multidiscrete form have len 6
self.apply_actions_to_acl(_action)
elif (
len(self.action_dict[_action]) == 4
): # Node actions in multdiscrete (array) from have len 4
self.apply_actions_to_nodes(_action)
else:
logging.error("Invalid action type found")
def apply_actions_to_nodes(self, _action):
"""
@@ -442,10 +448,11 @@ class Primaite(Env):
Args:
_action: The action space from the agent
"""
node_id = _action[0]
node_property = _action[1]
property_action = _action[2]
service_index = _action[3]
readable_action = self.action_dict[_action]
node_id = readable_action[0]
node_property = readable_action[1]
property_action = readable_action[2]
service_index = readable_action[3]
# Check that the action is requesting a valid node
try:
@@ -531,12 +538,15 @@ class Primaite(Env):
Args:
_action: The action space from the agent
"""
action_decision = _action[0]
action_permission = _action[1]
action_source_ip = _action[2]
action_destination_ip = _action[3]
action_protocol = _action[4]
action_port = _action[5]
# Convert discrete value back to multidiscrete
readable_action = self.action_dict[_action]
action_decision = readable_action[0]
action_permission = readable_action[1]
action_source_ip = readable_action[2]
action_destination_ip = readable_action[3]
action_protocol = readable_action[4]
action_port = readable_action[5]
if action_decision == 0:
# It's decided to do nothing
@@ -641,252 +651,20 @@ class Primaite(Env):
else:
pass
def _init_box_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the BOX option chosen.
This will create the observation space formatted as a table of integers.
There is one row per node, followed by one row per link.
Columns are as follows:
* node/link ID
* node hardware status / 0 for links
* node operating system status (if active/service) / 0 for links
* node file system status (active/service only) / 0 for links
* node service1 status / traffic load from that service for links
* node service2 status / traffic load from that service for links
* ...
* node serviceN status / traffic load from that service for links
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(12, 7)``
:return: Box gym observation
:rtype: gym.spaces.Box
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space type BOX selected")
# 1. Determine observation shape from laydown
num_items = self.num_links + self.num_nodes
num_observation_parameters = (
self.num_services + self.OBSERVATION_SPACE_FIXED_PARAMETERS
)
observation_shape = (num_items, num_observation_parameters)
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.Box(
low=0,
high=self.OBSERVATION_SPACE_HIGH_VALUE,
shape=observation_shape,
dtype=np.int64,
)
initial_observation = np.zeros(observation_shape, dtype=np.int64)
return observation_space, initial_observation
def _init_multidiscrete_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Initialise the observation space with the MULTIDISCRETE option chosen.
This will create the observation space with node observations followed by link observations.
Each node has 3 elements in the observation space plus 1 per service, more specifically:
* hardware state
* operating system state
* file system state
* service states (one per service)
Each link has one element in the observation space, corresponding to the traffic load,
it can take the following values:
0 = No traffic (0% of bandwidth)
1 = No traffic (0%-33% of bandwidth)
2 = No traffic (33%-66% of bandwidth)
3 = No traffic (66%-100% of bandwidth)
4 = No traffic (100% of bandwidth)
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
``(37,)``
:return: MultiDiscrete gym observation
:rtype: gym.spaces.MultiDiscrete
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
"""
_LOGGER.info("Observation space MULTIDISCRETE selected")
# 1. Determine observation shape from laydown
node_obs_shape = [
len(HardwareState) + 1,
len(SoftwareState) + 1,
len(FileSystemState) + 1,
]
node_services = [len(SoftwareState) + 1] * self.num_services
node_obs_shape = node_obs_shape + node_services
# the magic number 5 refers to 5 states of quantisation of traffic amount.
# (zero, low, medium, high, fully utilised/overwhelmed)
link_obs_shape = [5] * self.num_links
observation_shape = node_obs_shape * self.num_nodes + link_obs_shape
# 2. Create observation space & zeroed out sample from space.
observation_space = spaces.MultiDiscrete(observation_shape)
initial_observation = np.zeros(len(observation_shape), dtype=np.int64)
return observation_space, initial_observation
def init_observations(self) -> Tuple[spaces.Space, np.ndarray]:
"""Build the observation space based on network laydown and provide initial obs.
"""Create the environment's observation handler.
This method uses the object's `num_links`, `num_nodes`, `num_services`,
`OBSERVATION_SPACE_FIXED_PARAMETERS`, `OBSERVATION_SPACE_HIGH_VALUE`, and `observation_type`
attributes to figure out the correct shape and format for the observation space.
:raises ValueError: If the env's `observation_type` attribute is not set to a valid `enums.ObservationType`
:return: Gym observation space
:rtype: gym.spaces.Space
:return: Initial observation with all entires set to 0
:rtype: numpy.Array
:return: The observation space, initial observation (zeroed out array with the correct shape)
:rtype: Tuple[spaces.Space, np.ndarray]
"""
if self.observation_type == ObservationType.BOX:
observation_space, initial_observation = self._init_box_observations()
return observation_space, initial_observation
elif self.observation_type == ObservationType.MULTIDISCRETE:
(
observation_space,
initial_observation,
) = self._init_multidiscrete_observations()
return observation_space, initial_observation
else:
errmsg = (
f"Observation type must be {ObservationType.BOX} or {ObservationType.MULTIDISCRETE}"
f", got {self.observation_type} instead"
)
_LOGGER.error(errmsg)
raise ValueError(errmsg)
self.obs_handler = ObservationsHandler.from_config(self, self.obs_config)
def _update_env_obs_box(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_box_observations`
This function can only be called if the observation space setting is set to BOX.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.BOX
item_index = 0
# Do nodes first
for node_key, node in self.nodes.items():
self.env_obs[item_index][0] = int(node.node_id)
self.env_obs[item_index][1] = node.hardware_state.value
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
self.env_obs[item_index][2] = node.software_state.value
self.env_obs[item_index][3] = node.file_system_state_observed.value
else:
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
service_index = 4
if isinstance(node, ServiceNode):
for service in self.services_list:
if node.has_service(service):
self.env_obs[item_index][
service_index
] = node.get_service_state(service).value
else:
self.env_obs[item_index][service_index] = 0
service_index += 1
else:
# Not a service node
for service in self.services_list:
self.env_obs[item_index][service_index] = 0
service_index += 1
item_index += 1
# Now do links
for link_key, link in self.links.items():
self.env_obs[item_index][0] = int(link.get_id())
self.env_obs[item_index][1] = 0
self.env_obs[item_index][2] = 0
self.env_obs[item_index][3] = 0
protocol_list = link.get_protocol_list()
protocol_index = 0
for protocol in protocol_list:
self.env_obs[item_index][protocol_index + 4] = protocol.get_load()
protocol_index += 1
item_index += 1
def _update_env_obs_multidiscrete(self):
"""Update the environment's observation state based on the current status of nodes and links.
The structure of the observation space is described in :func:`~_init_multidiscrete_observations`
This function can only be called if the observation space setting is set to MULTIDISCRETE.
:raises AssertionError: If this function is called when the environment has the incorrect ``observation_type``
"""
assert self.observation_type == ObservationType.MULTIDISCRETE
obs = []
# 1. Set nodes
# Each node has the following variables in the observation space:
# - Hardware state
# - Software state
# - File System state
# - Service 1 state
# - Service 2 state
# - ...
# - Service N state
for node_key, node in self.nodes.items():
hardware_state = node.hardware_state.value
software_state = 0
file_system_state = 0
services_states = [0] * self.num_services
if isinstance(
node, ActiveNode
): # ServiceNode is a subclass of ActiveNode so no need to check that also
software_state = node.software_state.value
file_system_state = node.file_system_state_observed.value
if isinstance(node, ServiceNode):
for i, service in enumerate(self.services_list):
if node.has_service(service):
services_states[i] = node.get_service_state(service).value
obs.extend(
[
hardware_state,
software_state,
file_system_state,
*services_states,
]
)
# 2. Set links
# Each link has just one variable in the observation space, it represents the traffic amount
# In order for the space to be fully MultiDiscrete, the amount of
# traffic on each link is quantised into a few levels:
# 0: no traffic (0% of bandwidth)
# 1: low traffic (0-33% of bandwidth)
# 2: medium traffic (33-66% of bandwidth)
# 3: high traffic (66-100% of bandwidth)
# 4: max traffic/overloaded (100% of bandwidth)
for link_key, link in self.links.items():
bandwidth = link.bandwidth
load = link.get_current_load()
if load <= 0:
traffic_level = 0
elif load >= bandwidth:
traffic_level = 4
else:
traffic_level = (load / bandwidth) // (1 / 3) + 1
obs.append(int(traffic_level))
self.env_obs = np.asarray(obs)
return self.obs_handler.space, self.obs_handler.current_observation
def update_environent_obs(self):
"""Updates the observation space based on the node and link status."""
if self.observation_type == ObservationType.BOX:
self._update_env_obs_box()
elif self.observation_type == ObservationType.MULTIDISCRETE:
self._update_env_obs_multidiscrete()
self.obs_handler.update_obs()
self.env_obs = self.obs_handler.current_observation
def load_config(self):
"""Loads config data in order to build the environment configuration."""
@@ -921,9 +699,6 @@ class Primaite(Env):
elif item["itemType"] == "ACTIONS":
# Get the action information
self.get_action_info(item)
elif item["itemType"] == "OBSERVATIONS":
# Get the observation information
self.get_observation_info(item)
elif item["itemType"] == "STEPS":
# Get the steps information
self.get_steps_info(item)
@@ -1256,13 +1031,17 @@ class Primaite(Env):
"""
self.action_type = ActionType[action_info["type"]]
def get_observation_info(self, observation_info):
"""Extracts observation_info.
def save_obs_config(self, obs_config: dict):
"""Cache the config for the observation space.
:param observation_info: Config item that defines which type of observation space to use
:type observation_info: str
This is necessary as the observation space can't be built while reading the config,
it must be done after all the nodes, links, and services have been initialised.
:param obs_config: Parsed config relating to the observation space. The format is described in
:py:meth:`primaite.environment.observations.ObservationsHandler.from_config`
:type obs_config: dict
"""
self.observation_type = ObservationType[observation_info["type"]]
self.obs_config = obs_config
def get_steps_info(self, steps_info):
"""
@@ -1347,3 +1126,91 @@ class Primaite(Env):
else:
# Bad formatting
pass
def create_node_action_dict(self):
"""
Creates a dictionary mapping each possible discrete action to more readable multidiscrete action.
Note: Only actions that have the potential to change the state exist in the mapping (except for key 0)
example return:
{0: [1, 0, 0, 0],
1: [1, 1, 1, 0],
2: [1, 1, 2, 0],
3: [1, 1, 3, 0],
4: [1, 2, 1, 0],
5: [1, 3, 1, 0],
...
}
"""
# reserve 0 action to be a nothing action
actions = {0: [1, 0, 0, 0]}
action_key = 1
for node in range(1, self.num_nodes + 1):
# 4 node properties (NONE, OPERATING, OS, SERVICE)
for node_property in range(4):
# Node Actions either:
# (NONE, ON, OFF, RESET) - operating state OR (NONE, PATCH) - OS/service state
# Use MAX to ensure we get them all
for node_action in range(4):
for service_state in range(self.num_services):
action = [node, node_property, node_action, service_state]
# check to see if it's a nothing action (has no effect)
if is_valid_node_action(action):
actions[action_key] = action
action_key += 1
return actions
def create_acl_action_dict(self):
"""Creates a dictionary mapping each possible discrete action to more readable multidiscrete action."""
# reserve 0 action to be a nothing action
actions = {0: [0, 0, 0, 0, 0, 0]}
action_key = 1
# 3 possible action decisions, 0=NOTHING, 1=CREATE, 2=DELETE
for action_decision in range(3):
# 2 possible action permissions 0 = DENY, 1 = CREATE
for action_permission in range(2):
# Number of nodes + 1 (for any)
for source_ip in range(self.num_nodes + 1):
for dest_ip in range(self.num_nodes + 1):
for protocol in range(self.num_services + 1):
for port in range(self.num_ports + 1):
action = [
action_decision,
action_permission,
source_ip,
dest_ip,
protocol,
port,
]
# Check to see if its an action we want to include as possible i.e. not a nothing action
if is_valid_acl_action_extra(action):
actions[action_key] = action
action_key += 1
return actions
def create_node_and_acl_action_dict(self):
"""
Create a dictionary mapping each possible discrete action to a more readable mutlidiscrete action.
The dictionary contains actions of both Node and ACL action types.
"""
node_action_dict = self.create_node_action_dict()
acl_action_dict = self.create_acl_action_dict()
# Change node keys to not overlap with acl keys
# Only 1 nothing action (key 0) is required, remove the other
new_node_action_dict = {
k + len(acl_action_dict) - 1: v
for k, v in node_action_dict.items()
if k != 0
}
# Combine the Node dict and ACL dict
combined_action_dict = {**acl_action_dict, **new_node_action_dict}
return combined_action_dict

View File

@@ -24,6 +24,7 @@ from primaite.transactions.transactions_to_file import write_transaction_to_file
def run_generic():
"""Run against a generic agent."""
for episode in range(0, config_values.num_episodes):
env.reset()
for step in range(0, config_values.num_steps):
# Send the observation space to the agent to get an action
# TEMP - random action for now
@@ -41,7 +42,6 @@ def run_generic():
time.sleep(config_values.time_delay / 1000)
# Reset the environment at the end of the episode
env.reset()
env.close()
@@ -162,6 +162,10 @@ def load_config_values():
try:
# Generic
config_values.agent_identifier = config_data["agentIdentifier"]
if "observationSpace" in config_data:
config_values.observation_config = config_data["observationSpace"]
else:
config_values.observation_config = None
config_values.num_episodes = int(config_data["numEpisodes"])
config_values.time_delay = int(config_data["timeDelay"])
config_values.config_filename_use_case = (
@@ -376,7 +380,7 @@ logging.info("Saving transaction logs...")
write_transaction_to_file(transaction_list)
config_file_main.close
config_file_main.close()
print("Finished")
logging.info("Finished")