512 lines
19 KiB
Python
512 lines
19 KiB
Python
"""Module for handling configurable observation spaces in PrimAITE."""
|
|
import logging
|
|
from abc import ABC, abstractmethod
|
|
from typing import TYPE_CHECKING, Dict, Final, List, Tuple, Union
|
|
|
|
import numpy as np
|
|
from gym import spaces
|
|
|
|
from primaite.common.enums import (
|
|
FileSystemState,
|
|
HardwareState,
|
|
RulePermissionType,
|
|
SoftwareState,
|
|
)
|
|
from primaite.nodes.active_node import ActiveNode
|
|
from primaite.nodes.service_node import ServiceNode
|
|
|
|
# This dependency is only needed for type hints,
|
|
# TYPE_CHECKING is False at runtime and True when typecheckers are performing typechecking
|
|
# Therefore, this avoids circular dependency problem.
|
|
if TYPE_CHECKING:
|
|
from primaite.environment.primaite_env import Primaite
|
|
|
|
|
|
_LOGGER = logging.getLogger(__name__)
|
|
|
|
|
|
class AbstractObservationComponent(ABC):
|
|
"""Represents a part of the PrimAITE observation space."""
|
|
|
|
@abstractmethod
|
|
def __init__(self, env: "Primaite"):
|
|
_LOGGER.info(f"Initialising {self} observation component")
|
|
self.env: "Primaite" = env
|
|
self.space: spaces.Space
|
|
self.current_observation: np.ndarray # type might be too restrictive?
|
|
return NotImplemented
|
|
|
|
@abstractmethod
|
|
def update(self):
|
|
"""Update the observation based on the current state of the environment."""
|
|
self.current_observation = NotImplemented
|
|
|
|
|
|
class NodeLinkTable(AbstractObservationComponent):
|
|
"""Table with nodes and links as rows and hardware/software status as cols.
|
|
|
|
This will create the observation space formatted as a table of integers.
|
|
There is one row per node, followed by one row per link.
|
|
The number of columns is 4 plus one per service. They are:
|
|
* node/link ID
|
|
* node hardware status / 0 for links
|
|
* node operating system status (if active/service) / 0 for links
|
|
* node file system status (active/service only) / 0 for links
|
|
* node service1 status / traffic load from that service for links
|
|
* node service2 status / traffic load from that service for links
|
|
* ...
|
|
* node serviceN status / traffic load from that service for links
|
|
|
|
For example if the environment has 5 nodes, 7 links, and 3 services, the observation space shape will be
|
|
``(12, 7)``
|
|
"""
|
|
|
|
_FIXED_PARAMETERS: int = 4
|
|
_MAX_VAL: int = 1_000_000
|
|
_DATA_TYPE: type = np.int64
|
|
|
|
def __init__(self, env: "Primaite"):
|
|
super().__init__(env)
|
|
|
|
# 1. Define the shape of your observation space component
|
|
num_items = self.env.num_links + self.env.num_nodes
|
|
num_columns = self.env.num_services + self._FIXED_PARAMETERS
|
|
observation_shape = (num_items, num_columns)
|
|
|
|
# 2. Create Observation space
|
|
self.space = spaces.Box(
|
|
low=0,
|
|
high=self._MAX_VAL,
|
|
shape=observation_shape,
|
|
dtype=self._DATA_TYPE,
|
|
)
|
|
|
|
# 3. Initialise Observation with zeroes
|
|
self.current_observation = np.zeros(observation_shape, dtype=self._DATA_TYPE)
|
|
|
|
def update(self):
|
|
"""Update the observation based on current environment state.
|
|
|
|
The structure of the observation space is described in :class:`.NodeLinkTable`
|
|
"""
|
|
item_index = 0
|
|
nodes = self.env.nodes
|
|
links = self.env.links
|
|
# Do nodes first
|
|
for _, node in nodes.items():
|
|
self.current_observation[item_index][0] = int(node.node_id)
|
|
self.current_observation[item_index][1] = node.hardware_state.value
|
|
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
|
self.current_observation[item_index][2] = node.software_state.value
|
|
self.current_observation[item_index][
|
|
3
|
|
] = node.file_system_state_observed.value
|
|
else:
|
|
self.current_observation[item_index][2] = 0
|
|
self.current_observation[item_index][3] = 0
|
|
service_index = 4
|
|
if isinstance(node, ServiceNode):
|
|
for service in self.env.services_list:
|
|
if node.has_service(service):
|
|
self.current_observation[item_index][
|
|
service_index
|
|
] = node.get_service_state(service).value
|
|
else:
|
|
self.current_observation[item_index][service_index] = 0
|
|
service_index += 1
|
|
else:
|
|
# Not a service node
|
|
for service in self.env.services_list:
|
|
self.current_observation[item_index][service_index] = 0
|
|
service_index += 1
|
|
item_index += 1
|
|
|
|
# Now do links
|
|
for _, link in links.items():
|
|
self.current_observation[item_index][0] = int(link.get_id())
|
|
self.current_observation[item_index][1] = 0
|
|
self.current_observation[item_index][2] = 0
|
|
self.current_observation[item_index][3] = 0
|
|
protocol_list = link.get_protocol_list()
|
|
protocol_index = 0
|
|
for protocol in protocol_list:
|
|
self.current_observation[item_index][
|
|
protocol_index + 4
|
|
] = protocol.get_load()
|
|
protocol_index += 1
|
|
item_index += 1
|
|
|
|
|
|
class NodeStatuses(AbstractObservationComponent):
|
|
"""Flat list of nodes' hardware, OS, file system, and service states.
|
|
|
|
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
|
integers.
|
|
Each node has 3 elements plus 1 per service. It will have the following structure:
|
|
.. code-block::
|
|
[
|
|
node1 hardware state,
|
|
node1 OS state,
|
|
node1 file system state,
|
|
node1 service1 state,
|
|
node1 service2 state,
|
|
node1 serviceN state (one for each service),
|
|
node2 hardware state,
|
|
node2 OS state,
|
|
node2 file system state,
|
|
node2 service1 state,
|
|
node2 service2 state,
|
|
node2 serviceN state (one for each service),
|
|
...
|
|
]
|
|
|
|
:param env: The environment that forms the basis of the observations
|
|
:type env: Primaite
|
|
"""
|
|
|
|
_DATA_TYPE: type = np.int64
|
|
|
|
def __init__(self, env: "Primaite"):
|
|
super().__init__(env)
|
|
|
|
# 1. Define the shape of your observation space component
|
|
node_shape = [
|
|
len(HardwareState) + 1,
|
|
len(SoftwareState) + 1,
|
|
len(FileSystemState) + 1,
|
|
]
|
|
services_shape = [len(SoftwareState) + 1] * self.env.num_services
|
|
node_shape = node_shape + services_shape
|
|
|
|
shape = node_shape * self.env.num_nodes
|
|
# 2. Create Observation space
|
|
self.space = spaces.MultiDiscrete(shape)
|
|
|
|
# 3. Initialise observation with zeroes
|
|
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
|
|
|
def update(self):
|
|
"""Update the observation based on current environment state.
|
|
|
|
The structure of the observation space is described in :class:`.NodeStatuses`
|
|
"""
|
|
obs = []
|
|
for _, node in self.env.nodes.items():
|
|
hardware_state = node.hardware_state.value
|
|
software_state = 0
|
|
file_system_state = 0
|
|
service_states = [0] * self.env.num_services
|
|
|
|
if isinstance(node, ActiveNode):
|
|
software_state = node.software_state.value
|
|
file_system_state = node.file_system_state_observed.value
|
|
|
|
if isinstance(node, ServiceNode):
|
|
for i, service in enumerate(self.env.services_list):
|
|
if node.has_service(service):
|
|
service_states[i] = node.get_service_state(service).value
|
|
obs.extend(
|
|
[hardware_state, software_state, file_system_state, *service_states]
|
|
)
|
|
self.current_observation[:] = obs
|
|
|
|
|
|
class LinkTrafficLevels(AbstractObservationComponent):
|
|
"""Flat list of traffic levels encoded into banded categories.
|
|
|
|
For each link, total traffic or traffic per service is encoded into a categorical value.
|
|
For example, if ``quantisation_levels=5``, the traffic levels represent these values:
|
|
0 = No traffic (0% of bandwidth)
|
|
1 = No traffic (0%-33% of bandwidth)
|
|
2 = No traffic (33%-66% of bandwidth)
|
|
3 = No traffic (66%-100% of bandwidth)
|
|
4 = No traffic (100% of bandwidth)
|
|
|
|
.. note::
|
|
The lowest category always corresponds to no traffic and the highest category to the link being at max capacity.
|
|
Any amount of traffic between 0% and 100% (exclusive) is divided evenly into the remaining categories.
|
|
|
|
:param env: The environment that forms the basis of the observations
|
|
:type env: Primaite
|
|
:param combine_service_traffic: Whether to consider total traffic on the link, or each protocol individually,
|
|
defaults to False
|
|
:type combine_service_traffic: bool, optional
|
|
:param quantisation_levels: How many bands to consider when converting the traffic amount to a categorical value ,
|
|
defaults to 5
|
|
:type quantisation_levels: int, optional
|
|
"""
|
|
|
|
_DATA_TYPE: type = np.int64
|
|
|
|
def __init__(
|
|
self,
|
|
env: "Primaite",
|
|
combine_service_traffic: bool = False,
|
|
quantisation_levels: int = 5,
|
|
):
|
|
if quantisation_levels < 3:
|
|
_msg = (
|
|
f"quantisation_levels must be 3 or more because the lowest and highest levels are "
|
|
f"reserved for 0% and 100% link utilisation, got {quantisation_levels} instead. "
|
|
f"Resetting to default value (5)"
|
|
)
|
|
_LOGGER.warning(_msg)
|
|
quantisation_levels = 5
|
|
|
|
super().__init__(env)
|
|
|
|
self._combine_service_traffic: bool = combine_service_traffic
|
|
self._quantisation_levels: int = quantisation_levels
|
|
self._entries_per_link: int = 1
|
|
|
|
if not self._combine_service_traffic:
|
|
self._entries_per_link = self.env.num_services
|
|
|
|
# 1. Define the shape of your observation space component
|
|
shape = (
|
|
[self._quantisation_levels] * self.env.num_links * self._entries_per_link
|
|
)
|
|
|
|
# 2. Create Observation space
|
|
self.space = spaces.MultiDiscrete(shape)
|
|
|
|
# 3. Initialise observation with zeroes
|
|
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
|
|
|
def update(self):
|
|
"""Update the observation based on current environment state.
|
|
|
|
The structure of the observation space is described in :class:`.LinkTrafficLevels`
|
|
"""
|
|
obs = []
|
|
for _, link in self.env.links.items():
|
|
bandwidth = link.bandwidth
|
|
if self._combine_service_traffic:
|
|
loads = [link.get_current_load()]
|
|
else:
|
|
loads = [protocol.get_load() for protocol in link.protocol_list]
|
|
|
|
for load in loads:
|
|
if load <= 0:
|
|
traffic_level = 0
|
|
elif load >= bandwidth:
|
|
traffic_level = self._quantisation_levels - 1
|
|
else:
|
|
traffic_level = (load / bandwidth) // (
|
|
1 / (self._quantisation_levels - 2)
|
|
) + 1
|
|
|
|
obs.append(int(traffic_level))
|
|
|
|
self.current_observation[:] = obs
|
|
|
|
|
|
class AccessControlList(AbstractObservationComponent):
|
|
"""Flat list of all the Access Control Rules in the Access Control List.
|
|
|
|
The MultiDiscrete observation space can be though of as a one-dimensional vector of discrete states, represented by
|
|
integers.
|
|
|
|
:param env: The environment that forms the basis of the observations
|
|
:type env: Primaite
|
|
:param acl_implicit_rule: Whether to have an implicit DENY or implicit ALLOW ACL rule at the end of the ACL list
|
|
Default is 0 DENY, 1 ALLOW
|
|
:type acl_implicit_rule: ImplicitFirewallRule Enumeration (ALLOW or DENY)
|
|
:param max_acl_rules: Maximum number of ACLs allowed in the environment
|
|
:type max_acl_rules: int
|
|
|
|
Each ACL Rule has 6 elements. It will have the following structure:
|
|
.. code-block::
|
|
[
|
|
acl_rule1 permission,
|
|
acl_rule1 source_ip,
|
|
acl_rule1 dest_ip,
|
|
acl_rule1 protocol,
|
|
acl_rule1 port,
|
|
acl_rule1 position,
|
|
acl_rule2 permission,
|
|
acl_rule2 source_ip,
|
|
acl_rule2 dest_ip,
|
|
acl_rule2 protocol,
|
|
acl_rule2 port,
|
|
acl_rule2 position,
|
|
...
|
|
]
|
|
"""
|
|
|
|
_DATA_TYPE: type = np.int64
|
|
|
|
def __init__(self, env: "Primaite"):
|
|
super().__init__(env)
|
|
|
|
# 1. Define the shape of your observation space component
|
|
acl_shape = [
|
|
len(RulePermissionType),
|
|
len(env.nodes),
|
|
len(env.nodes),
|
|
len(env.services_list),
|
|
len(env.ports_list),
|
|
]
|
|
shape = acl_shape * self.env.max_acl_rules
|
|
|
|
# 2. Create Observation space
|
|
self.space = spaces.MultiDiscrete(shape)
|
|
|
|
# 3. Initialise observation with zeroes
|
|
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
|
|
|
# Dictionary to map services to numbers for obs space
|
|
self.services_dict = {}
|
|
|
|
def update(self):
|
|
"""Update the observation based on current environment state.
|
|
|
|
The structure of the observation space is described in :class:`.AccessControlList`
|
|
"""
|
|
obs = []
|
|
for acl_rule in self.env.acl.acl:
|
|
permission = acl_rule.permission
|
|
source_ip = acl_rule.source_ip
|
|
dest_ip = acl_rule.dest_ip
|
|
protocol = acl_rule.protocol
|
|
port = acl_rule.port
|
|
position = self.env.acl.acl.index(acl_rule)
|
|
if permission == "DENY":
|
|
permission_int = 0
|
|
else:
|
|
permission_int = 1
|
|
|
|
if source_ip == "ANY":
|
|
source_ip = 0
|
|
if dest_ip == "ANY":
|
|
dest_ip = 0
|
|
if port == "ANY":
|
|
port = 0
|
|
if protocol == "ANY":
|
|
protocol_int = 0
|
|
else:
|
|
while True:
|
|
if protocol in self.service_dict:
|
|
protocol_int = self.services_dict[protocol]
|
|
break
|
|
else:
|
|
self.services_dict[protocol] = len(self.services_dict) + 1
|
|
continue
|
|
# [0 - DENY, 1 - ALLOW] Permission
|
|
# [0 - ANY, x - IP Address/Protocol/Port]
|
|
|
|
print(permission_int, source_ip, dest_ip, protocol_int, port)
|
|
obs.extend(
|
|
[permission_int, source_ip, dest_ip, protocol_int, port, position]
|
|
)
|
|
|
|
self.current_observation[:] = obs
|
|
|
|
|
|
class ObservationsHandler:
|
|
"""Component-based observation space handler.
|
|
|
|
This allows users to configure observation spaces by mixing and matching components.
|
|
Each component can also define further parameters to make them more flexible.
|
|
"""
|
|
|
|
_REGISTRY: Final[Dict[str, type]] = {
|
|
"NODE_LINK_TABLE": NodeLinkTable,
|
|
"NODE_STATUSES": NodeStatuses,
|
|
"LINK_TRAFFIC_LEVELS": LinkTrafficLevels,
|
|
"ACCESS_CONTROL_LIST": AccessControlList,
|
|
}
|
|
|
|
def __init__(self):
|
|
self.registered_obs_components: List[AbstractObservationComponent] = []
|
|
self.space: spaces.Space
|
|
self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
|
|
|
|
def update_obs(self):
|
|
"""Fetch fresh information about the environment."""
|
|
current_obs = []
|
|
for obs in self.registered_obs_components:
|
|
obs.update()
|
|
current_obs.append(obs.current_observation)
|
|
|
|
# If there is only one component, don't use a tuple, just pass through that component's obs.
|
|
if len(current_obs) == 1:
|
|
self.current_observation = current_obs[0]
|
|
else:
|
|
self.current_observation = tuple(current_obs)
|
|
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
|
|
|
def register(self, obs_component: AbstractObservationComponent):
|
|
"""Add a component for this handler to track.
|
|
|
|
:param obs_component: The component to add.
|
|
:type obs_component: AbstractObservationComponent
|
|
"""
|
|
self.registered_obs_components.append(obs_component)
|
|
self.update_space()
|
|
|
|
def deregister(self, obs_component: AbstractObservationComponent):
|
|
"""Remove a component from this handler.
|
|
|
|
:param obs_component: Which component to remove. It must exist within this object's
|
|
``registered_obs_components`` attribute.
|
|
:type obs_component: AbstractObservationComponent
|
|
"""
|
|
self.registered_obs_components.remove(obs_component)
|
|
self.update_space()
|
|
|
|
def update_space(self):
|
|
"""Rebuild the handler's composite observation space from its components."""
|
|
component_spaces = []
|
|
for obs_comp in self.registered_obs_components:
|
|
component_spaces.append(obs_comp.space)
|
|
|
|
# If there is only one component, don't use a tuple space, just pass through that component's space.
|
|
if len(component_spaces) == 1:
|
|
self.space = component_spaces[0]
|
|
else:
|
|
self.space = spaces.Tuple(component_spaces)
|
|
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
|
|
|
|
@classmethod
|
|
def from_config(cls, env: "Primaite", obs_space_config: dict):
|
|
"""Parse a config dictinary, return a new observation handler populated with new observation component objects.
|
|
|
|
The expected format for the config dictionary is:
|
|
|
|
..code-block::python
|
|
config = {
|
|
components: [
|
|
{
|
|
"name": "<COMPONENT1_NAME>"
|
|
},
|
|
{
|
|
"name": "<COMPONENT2_NAME>"
|
|
"options": {"opt1": val1, "opt2": val2}
|
|
},
|
|
{
|
|
...
|
|
},
|
|
]
|
|
}
|
|
|
|
:return: Observation handler
|
|
:rtype: primaite.environment.observations.ObservationsHandler
|
|
"""
|
|
# Instantiate the handler
|
|
handler = cls()
|
|
|
|
for component_cfg in obs_space_config["components"]:
|
|
# Figure out which class can instantiate the desired component
|
|
comp_type = component_cfg["name"]
|
|
comp_class = cls._REGISTRY[comp_type]
|
|
|
|
# Create the component with options from the YAML
|
|
options = component_cfg.get("options") or {}
|
|
component = comp_class(env, **options)
|
|
|
|
handler.register(component)
|
|
|
|
handler.update_obs()
|
|
return handler
|