Add tests for observations

This commit is contained in:
Marek Wolan
2023-06-02 12:59:01 +01:00
parent 602bf9ba9a
commit 14096b3dd1
8 changed files with 476 additions and 32 deletions

View File

@@ -1,3 +1,5 @@
[pytest] [pytest]
testpaths = testpaths =
tests tests
markers =
env_config_paths

View File

@@ -165,14 +165,15 @@ class NodeStatuses(AbstractObservationComponent):
super().__init__(env) super().__init__(env)
# 1. Define the shape of your observation space component # 1. Define the shape of your observation space component
shape = [ node_shape = [
len(HardwareState) + 1, len(HardwareState) + 1,
len(SoftwareState) + 1, len(SoftwareState) + 1,
len(FileSystemState) + 1, len(FileSystemState) + 1,
] ]
services_shape = [len(SoftwareState) + 1] * self.env.num_services services_shape = [len(SoftwareState) + 1] * self.env.num_services
shape = shape + services_shape node_shape = node_shape + services_shape
shape = node_shape * self.env.num_nodes
# 2. Create Observation space # 2. Create Observation space
self.space = spaces.MultiDiscrete(shape) self.space = spaces.MultiDiscrete(shape)
@@ -199,7 +200,9 @@ class NodeStatuses(AbstractObservationComponent):
for i, service in enumerate(self.env.services_list): for i, service in enumerate(self.env.services_list):
if node.has_service(service): if node.has_service(service):
service_states[i] = node.get_service_state(service).value service_states[i] = node.get_service_state(service).value
obs.extend([hardware_state, software_state, file_system_state, *service_states]) obs.extend(
[hardware_state, software_state, file_system_state, *service_states]
)
self.current_observation[:] = obs self.current_observation[:] = obs
@@ -303,8 +306,6 @@ class ObservationsHandler:
self.registered_obs_components: List[AbstractObservationComponent] = [] self.registered_obs_components: List[AbstractObservationComponent] = []
self.space: spaces.Space self.space: spaces.Space
self.current_observation: Union[Tuple[np.ndarray], np.ndarray] self.current_observation: Union[Tuple[np.ndarray], np.ndarray]
# i can access the registry items like this:
# self.registry.LINK_TRAFFIC_LEVELS
def update_obs(self): def update_obs(self):
"""Fetch fresh information about the environment.""" """Fetch fresh information about the environment."""
@@ -318,6 +319,7 @@ class ObservationsHandler:
self.current_observation = current_obs[0] self.current_observation = current_obs[0]
else: else:
self.current_observation = tuple(current_obs) self.current_observation = tuple(current_obs)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
def register(self, obs_component: AbstractObservationComponent): def register(self, obs_component: AbstractObservationComponent):
"""Add a component for this handler to track. """Add a component for this handler to track.
@@ -349,6 +351,7 @@ class ObservationsHandler:
self.space = component_spaces[0] self.space = component_spaces[0]
else: else:
self.space = spaces.Tuple(component_spaces) self.space = spaces.Tuple(component_spaces)
# TODO: We may need to add ability to flatten the space as not all agents support tuple spaces.
@classmethod @classmethod
def from_config(cls, env: "Primaite", obs_space_config: dict): def from_config(cls, env: "Primaite", obs_space_config: dict):

View File

@@ -2,15 +2,20 @@
type: NODE type: NODE
- itemType: OBSERVATION_SPACE - itemType: OBSERVATION_SPACE
components: components:
- name: NODE_STATUSES - name: LINK_TRAFFIC_LEVELS
options:
combine_service_traffic: false
quantisation_levels: 8
- itemType: STEPS - itemType: STEPS
steps: 5 steps: 5
- itemType: PORTS - itemType: PORTS
portsList: portsList:
- port: '80' - port: '80'
- port: '53'
- itemType: SERVICES - itemType: SERVICES
serviceList: serviceList:
- name: TCP - name: TCP
- name: UDP
######################################## ########################################
# Nodes # Nodes
@@ -28,6 +33,9 @@
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE - itemType: NODE
node_id: '2' node_id: '2'
name: SERVER name: SERVER
@@ -42,6 +50,9 @@
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE - itemType: NODE
node_id: '3' node_id: '3'
name: SWITCH1 name: SWITCH1
@@ -67,3 +78,33 @@
bandwidth: 1000 bandwidth: 1000
source: '3' source: '3'
destination: '2' destination: '2'
#########################################
# IERS
- itemType: GREEN_IER
id: '5'
startStep: 0
endStep: 5
load: 20
protocol: TCP
port: '80'
source: '1'
destination: '2'
missionCriticality: 5
#########################################
# ACL Rules
- itemType: ACL_RULE
id: '6'
permission: ALLOW
source: 192.168.1.1
destination: 192.168.1.2
protocol: TCP
port: 80
- itemType: ACL_RULE
id: '7'
permission: ALLOW
source: 192.168.1.2
destination: 192.168.1.1
protocol: TCP
port: 80

View File

@@ -3,17 +3,16 @@
- itemType: OBSERVATION_SPACE - itemType: OBSERVATION_SPACE
components: components:
- name: NODE_LINK_TABLE - name: NODE_LINK_TABLE
options:
- combine_service_traffic: false
- quantisation_levels: 8
- itemType: STEPS - itemType: STEPS
steps: 5 steps: 5
- itemType: PORTS - itemType: PORTS
portsList: portsList:
- port: '80' - port: '80'
- port: '53'
- itemType: SERVICES - itemType: SERVICES
serviceList: serviceList:
- name: TCP - name: TCP
- name: UDP
######################################## ########################################
# Nodes # Nodes
@@ -31,6 +30,9 @@
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE - itemType: NODE
node_id: '2' node_id: '2'
name: SERVER name: SERVER
@@ -45,6 +47,9 @@
- name: TCP - name: TCP
port: '80' port: '80'
state: GOOD state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE - itemType: NODE
node_id: '3' node_id: '3'
name: SWITCH1 name: SWITCH1

View File

@@ -0,0 +1,107 @@
- itemType: ACTIONS
type: NODE
- itemType: OBSERVATION_SPACE
components:
- name: NODE_STATUSES
- itemType: STEPS
steps: 5
- itemType: PORTS
portsList:
- port: '80'
- port: '53'
- itemType: SERVICES
serviceList:
- name: TCP
- name: UDP
########################################
# Nodes
- itemType: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: COMPROMISED
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: OVERWHELMED
- itemType: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- itemType: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- itemType: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'
#########################################
# IERS
- itemType: GREEN_IER
id: '5'
startStep: 0
endStep: 5
load: 20
protocol: TCP
port: '80'
source: '1'
destination: '2'
missionCriticality: 5
#########################################
# ACL Rules
- itemType: ACL_RULE
id: '6'
permission: ALLOW
source: 192.168.1.1
destination: 192.168.1.2
protocol: TCP
port: 80
- itemType: ACL_RULE
id: '7'
permission: ALLOW
source: 192.168.1.2
destination: 192.168.1.1
protocol: TCP
port: 80

View File

@@ -0,0 +1,74 @@
- itemType: ACTIONS
type: NODE
- itemType: STEPS
steps: 5
- itemType: PORTS
portsList:
- port: '80'
- port: '53'
- itemType: SERVICES
serviceList:
- name: TCP
- name: UDP
########################################
# Nodes
- itemType: NODE
node_id: '1'
name: PC1
node_class: SERVICE
node_type: COMPUTER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.1
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '2'
name: SERVER
node_class: SERVICE
node_type: SERVER
priority: P5
hardware_state: 'ON'
ip_address: 192.168.1.2
software_state: GOOD
file_system_state: GOOD
services:
- name: TCP
port: '80'
state: GOOD
- name: UDP
port: '53'
state: GOOD
- itemType: NODE
node_id: '3'
name: SWITCH1
node_class: ACTIVE
node_type: SWITCH
priority: P2
hardware_state: 'ON'
ip_address: 192.168.1.3
software_state: GOOD
file_system_state: GOOD
########################################
# Links
- itemType: LINK
id: '4'
name: link1
bandwidth: 1000
source: '1'
destination: '3'
- itemType: LINK
id: '5'
name: link2
bandwidth: 1000
source: '3'
destination: '2'

View File

@@ -0,0 +1,89 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agentIdentifier: NONE
# Number of episodes to run per session
numEpisodes: 1
# Time delay between steps (for generic agents)
timeDelay: 1
# Filename of the scenario / laydown
configFilename: one_node_states_on_off_lay_down_config.yaml
# Type of session to be run (TRAINING or EVALUATION)
sessionType: TRAINING
# Determine whether to load an agent from file
loadAgent: False
# File path and file name of agent if you're loading one in
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
# Environment config values
# The high value for the observation space
observationSpaceHighValue: 1000000000
# Reward values
# Generic
allOk: 0
# Node Hardware State
offShouldBeOn: -10
offShouldBeResetting: -5
onShouldBeOff: -2
onShouldBeResetting: -5
resettingShouldBeOn: -5
resettingShouldBeOff: -2
resetting: -3
# Node Software or Service State
goodShouldBePatching: 2
goodShouldBeCompromised: 5
goodShouldBeOverwhelmed: 5
patchingShouldBeGood: -5
patchingShouldBeCompromised: 2
patchingShouldBeOverwhelmed: 2
patching: -3
compromisedShouldBeGood: -20
compromisedShouldBePatching: -20
compromisedShouldBeOverwhelmed: -20
compromised: -20
overwhelmedShouldBeGood: -20
overwhelmedShouldBePatching: -20
overwhelmedShouldBeCompromised: -20
overwhelmed: -20
# Node File System State
goodShouldBeRepairing: 2
goodShouldBeRestoring: 2
goodShouldBeCorrupt: 5
goodShouldBeDestroyed: 10
repairingShouldBeGood: -5
repairingShouldBeRestoring: 2
repairingShouldBeCorrupt: 2
repairingShouldBeDestroyed: 0
repairing: -3
restoringShouldBeGood: -10
restoringShouldBeRepairing: -2
restoringShouldBeCorrupt: 1
restoringShouldBeDestroyed: 2
restoring: -6
corruptShouldBeGood: -10
corruptShouldBeRepairing: -10
corruptShouldBeRestoring: -10
corruptShouldBeDestroyed: 2
corrupt: -10
destroyedShouldBeGood: -20
destroyedShouldBeRepairing: -20
destroyedShouldBeRestoring: -20
destroyedShouldBeCorrupt: -20
destroyed: -20
scanning: -2
# IER status
redIerRunning: -5
greenIerBlocked: -10
# Patching / Reset durations
osPatchingDuration: 5 # The time taken to patch the OS
nodeResetDuration: 5 # The time taken to reset a node (hardware)
servicePatchingDuration: 5 # The time taken to patch a service
fileSystemRepairingLimit: 5 # The time take to repair the file system
fileSystemRestoringLimit: 5 # The time take to restore the file system
fileSystemScanningLimit: 5 # The time taken to scan the file system

View File

@@ -1,45 +1,168 @@
"""Test env creation and behaviour with different observation spaces.""" """Test env creation and behaviour with different observation spaces."""
import numpy as np
import pytest
from primaite.environment.observations import NodeStatuses, ObservationsHandler from primaite.environment.observations import (
NodeLinkTable,
NodeStatuses,
ObservationsHandler,
)
from primaite.environment.primaite_env import Primaite
from tests import TEST_CONFIG_ROOT from tests import TEST_CONFIG_ROOT
from tests.conftest import _get_primaite_env_from_config from tests.conftest import _get_primaite_env_from_config
def test_creating_env_with_box_obs(): @pytest.fixture
"""Try creating env with box observation space.""" def env(request):
"""Build Primaite environment for integration tests of observation space."""
marker = request.node.get_closest_marker("env_config_paths")
main_config_path = marker.args[0]["main_config_path"]
lay_down_config_path = marker.args[0]["lay_down_config_path"]
env = _get_primaite_env_from_config( env = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", main_config_path=main_config_path,
lay_down_config_path=TEST_CONFIG_ROOT / "box_obs_space_laydown_config.yaml", lay_down_config_path=lay_down_config_path,
) )
env.update_environent_obs() yield env
# we have three nodes and two links, with one service
# therefore the box observation space will have:
# * 5 columns (four fixed and one for the service)
# * 5 rows (3 nodes + 2 links)
assert env.env_obs.shape == (5, 5)
def test_creating_env_with_multidiscrete_obs(): @pytest.mark.env_config_paths(
"""Try creating env with MultiDiscrete observation space.""" dict(
env = _get_primaite_env_from_config(
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml", main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT lay_down_config_path=TEST_CONFIG_ROOT
/ "multidiscrete_obs_space_laydown_config.yaml", / "obs_tests/laydown_without_obs_space.yaml",
) )
)
def test_default_obs_space(env: Primaite):
"""Create environment with no obs space defined in config and check that the default obs space was created."""
env.update_environent_obs() env.update_environent_obs()
# we have three nodes and two links, with one service components = env.obs_handler.registered_obs_components
# the nodes have hardware, OS, FS, and service, the links just have bandwidth,
# therefore we need 3*4 + 2 observations assert len(components) == 1
assert env.env_obs.shape == (3 * 4 + 2,) assert isinstance(components[0], NodeLinkTable)
def test_component_registration(): @pytest.mark.env_config_paths(
"""Test that we can register and deregister a component.""" dict(
main_config_path=TEST_CONFIG_ROOT / "one_node_states_on_off_main_config.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "obs_tests/laydown_without_obs_space.yaml",
)
)
def test_registering_components(env: Primaite):
"""Test regitering and deregistering a component."""
handler = ObservationsHandler() handler = ObservationsHandler()
component = NodeStatuses() component = NodeStatuses(env)
handler.register(component) handler.register(component)
assert component in handler.registered_obs_components assert component in handler.registered_obs_components
handler.deregister(component) handler.deregister(component)
assert component not in handler.registered_obs_components assert component not in handler.registered_obs_components
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "obs_tests/laydown_with_NODE_LINK_TABLE.yaml",
)
)
class TestNodeLinkTable:
"""Test the NodeLinkTable observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
"""Try creating env with box observation space."""
env.update_environent_obs()
# we have three nodes and two links, with two service
# therefore the box observation space will have:
# * 5 rows (3 nodes + 2 links)
# * 6 columns (four fixed and two for the services)
assert env.env_obs.shape == (5, 6)
# def test_value(self, env: Primaite):
# """"""
# ...
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "obs_tests/laydown_with_NODE_STATUSES.yaml",
)
)
class TestNodeStatuses:
"""Test the NodeStatuses observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
"""Try creating env with NodeStatuses as the only component."""
assert env.env_obs.shape == (15)
def test_values(self, env: Primaite):
"""Test that the hardware and software states are encoded correctly.
The laydown has:
* one node with a compromised operating system state
* one node with two services, and the second service is overwhelmed.
* all other states are good or null
Therefore, the expected state is:
* node 1:
* hardware = good (1)
* OS = compromised (3)
* file system = good (1)
* service 1 = good (1)
* service 2 = good (1)
* node 2:
* hardware = good (1)
* OS = good (1)
* file system = good (1)
* service 1 = good (1)
* service 2 = overwhelmed (4)
* node 3 (switch):
* hardware = good (1)
* OS = good (1)
* file system = good (1)
* service 1 = n/a (0)
* service 2 = n/a (0)
"""
act = np.asarray([0, 0, 0, 0])
obs, _, _, _ = env.step(act)
assert np.array_equal(obs, [1, 3, 1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 0, 0])
@pytest.mark.env_config_paths(
dict(
main_config_path=TEST_CONFIG_ROOT / "obs_tests/main_config_no_agent.yaml",
lay_down_config_path=TEST_CONFIG_ROOT
/ "obs_tests/laydown_with_LINK_TRAFFIC_LEVELS.yaml",
)
)
class TestLinkTrafficLevels:
"""Test the LinkTrafficLevels observation component (in isolation)."""
def test_obs_shape(self, env: Primaite):
"""Try creating env with MultiDiscrete observation space."""
env.update_environent_obs()
# we have two links and two services, so the shape should be 2 * 2
assert env.env_obs.shape == (2 * 2,)
def test_values(self, env: Primaite):
"""Test that traffic values are encoded correctly.
The laydown has:
* two services
* three nodes
* two links
* an IER trying to send 20 bits of data over both links the whole time (via the first service)
* link bandwidth of 1000, therefore the utilisation is 2%
"""
act = np.asarray([0, 0, 0, 0])
obs, reward, done, info = env.step(act)
obs, reward, done, info = env.step(act)
# the observation space has combine_service_traffic set to False, so the space has this format:
# [link1_service1, link1_service2, link2_service1, link2_service2]
# we send 20 bits of data via link1 and link2 on service 1.
# therefore the first and third elements should be 1 and all others 0
assert np.array_equal(obs, [1, 0, 1, 0])