901 - started testing for observation space
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
# Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence.
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Final, Optional, Union
|
||||
@@ -9,6 +10,7 @@ from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite.common.enums import ActionType
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
logging.basicConfig(level=logging.DEBUG, format="%(message)s")
|
||||
|
||||
_EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training"
|
||||
|
||||
|
||||
@@ -331,7 +331,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
"""
|
||||
|
||||
# Terms (for ACL observation space):
|
||||
# [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)
|
||||
# [0, 1] - Permission (0 = DENY, 1 = ALLOW)
|
||||
# [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
# [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses)
|
||||
@@ -352,7 +351,7 @@ class AccessControlList(AbstractObservationComponent):
|
||||
len(env.services_list),
|
||||
len(env.ports_list),
|
||||
]
|
||||
shape = acl_shape * self.env.max_acl_rules
|
||||
shape = acl_shape * self.env.max_number_acl_rules
|
||||
|
||||
# 2. Create Observation space
|
||||
self.space = spaces.MultiDiscrete(shape)
|
||||
@@ -360,9 +359,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
# 3. Initialise observation with zeroes
|
||||
self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE)
|
||||
|
||||
# Dictionary to map services to numbers for obs space
|
||||
self.services_dict = {}
|
||||
|
||||
def update(self):
|
||||
"""Update the observation based on current environment state.
|
||||
|
||||
@@ -380,7 +376,6 @@ class AccessControlList(AbstractObservationComponent):
|
||||
permission_int = 0
|
||||
else:
|
||||
permission_int = 1
|
||||
|
||||
if source_ip == "ANY":
|
||||
source_ip_int = 0
|
||||
else:
|
||||
@@ -393,9 +388,10 @@ class AccessControlList(AbstractObservationComponent):
|
||||
protocol_int = 0
|
||||
else:
|
||||
try:
|
||||
protocol_int = Protocol[protocol]
|
||||
protocol_int = Protocol[protocol].value
|
||||
except AttributeError:
|
||||
_LOGGER.info(f"Service {protocol} could not be found")
|
||||
protocol_int = -1
|
||||
if port == "ANY":
|
||||
port_int = 0
|
||||
else:
|
||||
@@ -423,7 +419,8 @@ class AccessControlList(AbstractObservationComponent):
|
||||
|
||||
Resolves IP address -> x (node id e.g. 1 or 2 or 3 or 4) for observation space
|
||||
"""
|
||||
for key, node in self.env.nodes:
|
||||
print(type(self.env.nodes))
|
||||
for key, node in self.env.nodes.items():
|
||||
if isinstance(node, ActiveNode) or isinstance(node, ServiceNode):
|
||||
if node.ip_address == ip_address:
|
||||
return key
|
||||
|
||||
@@ -122,6 +122,9 @@ class Primaite(Env):
|
||||
self.training_config.implicit_acl_rule,
|
||||
self.training_config.max_number_acl_rules,
|
||||
)
|
||||
# Sets limit for number of ACL rules in environment
|
||||
self.max_number_acl_rules = self.training_config.max_number_acl_rules
|
||||
|
||||
# Create a list of services (enums)
|
||||
self.services_list = []
|
||||
|
||||
|
||||
@@ -5,14 +5,16 @@
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agentIdentifier: NONE
|
||||
agent_identifier: STABLE_BASELINES3_A2C
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: ANY
|
||||
# Number of episodes to run per session
|
||||
observationSpace:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
options:
|
||||
implicit_acl_rule: DENY
|
||||
max_number_of_acl_rules: 10
|
||||
num_episodes: 1
|
||||
# Number of time_steps per episode
|
||||
num_steps: 5
|
||||
|
||||
# Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE)
|
||||
apply_implicit_rule: True
|
||||
@@ -20,83 +22,86 @@ apply_implicit_rule: True
|
||||
implicit_acl_rule: DENY
|
||||
# Total number of ACL rules allowed in the environment
|
||||
max_number_acl_rules: 10
|
||||
numEpisodes: 1
|
||||
|
||||
observation_space:
|
||||
components:
|
||||
- name: ACCESS_CONTROL_LIST
|
||||
|
||||
# Time delay between steps (for generic agents)
|
||||
timeDelay: 1
|
||||
# Filename of the scenario / laydown
|
||||
configFilename: one_node_states_on_off_lay_down_config.yaml
|
||||
time_delay: 1
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
sessionType: TRAINING
|
||||
session_type: TRAINING
|
||||
# Determine whether to load an agent from file
|
||||
loadAgent: False
|
||||
load_agent: False
|
||||
# File path and file name of agent if you're loading one in
|
||||
agentLoadFile: C:\[Path]\[agent_saved_filename.zip]
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
# Environment config values
|
||||
# The high value for the observation space
|
||||
observationSpaceHighValue: 1_000_000_000
|
||||
observation_space_high_value: 1_000_000_000
|
||||
|
||||
# Reward values
|
||||
# Generic
|
||||
allOk: 0
|
||||
all_ok: 0
|
||||
# Node Hardware State
|
||||
offShouldBeOn: -10
|
||||
offShouldBeResetting: -5
|
||||
onShouldBeOff: -2
|
||||
onShouldBeResetting: -5
|
||||
resettingShouldBeOn: -5
|
||||
resettingShouldBeOff: -2
|
||||
off_should_be_on: -10
|
||||
off_should_be_resetting: -5
|
||||
on_should_be_off: -2
|
||||
on_should_be_resetting: -5
|
||||
resetting_should_be_on: -5
|
||||
resetting_should_be_off: -2
|
||||
resetting: -3
|
||||
# Node Software or Service State
|
||||
goodShouldBePatching: 2
|
||||
goodShouldBeCompromised: 5
|
||||
goodShouldBeOverwhelmed: 5
|
||||
patchingShouldBeGood: -5
|
||||
patchingShouldBeCompromised: 2
|
||||
patchingShouldBeOverwhelmed: 2
|
||||
good_should_be_patching: 2
|
||||
good_should_be_compromised: 5
|
||||
good_should_be_overwhelmed: 5
|
||||
patching_should_be_good: -5
|
||||
patching_should_be_compromised: 2
|
||||
patching_should_be_overwhelmed: 2
|
||||
patching: -3
|
||||
compromisedShouldBeGood: -20
|
||||
compromisedShouldBePatching: -20
|
||||
compromisedShouldBeOverwhelmed: -20
|
||||
compromised_should_be_good: -20
|
||||
compromised_should_be_patching: -20
|
||||
compromised_should_be_overwhelmed: -20
|
||||
compromised: -20
|
||||
overwhelmedShouldBeGood: -20
|
||||
overwhelmedShouldBePatching: -20
|
||||
overwhelmedShouldBeCompromised: -20
|
||||
overwhelmed_should_be_good: -20
|
||||
overwhelmed_should_be_patching: -20
|
||||
overwhelmed_should_be_compromised: -20
|
||||
overwhelmed: -20
|
||||
# Node File System State
|
||||
goodShouldBeRepairing: 2
|
||||
goodShouldBeRestoring: 2
|
||||
goodShouldBeCorrupt: 5
|
||||
goodShouldBeDestroyed: 10
|
||||
repairingShouldBeGood: -5
|
||||
repairingShouldBeRestoring: 2
|
||||
repairingShouldBeCorrupt: 2
|
||||
repairingShouldBeDestroyed: 0
|
||||
good_should_be_repairing: 2
|
||||
good_should_be_restoring: 2
|
||||
good_should_be_corrupt: 5
|
||||
good_should_be_destroyed: 10
|
||||
repairing_should_be_good: -5
|
||||
repairing_should_be_restoring: 2
|
||||
repairing_should_be_corrupt: 2
|
||||
repairing_should_be_destroyed: 0
|
||||
repairing: -3
|
||||
restoringShouldBeGood: -10
|
||||
restoringShouldBeRepairing: -2
|
||||
restoringShouldBeCorrupt: 1
|
||||
restoringShouldBeDestroyed: 2
|
||||
restoring_should_be_good: -10
|
||||
restoring_should_be_repairing: -2
|
||||
restoring_should_be_corrupt: 1
|
||||
restoring_should_be_destroyed: 2
|
||||
restoring: -6
|
||||
corruptShouldBeGood: -10
|
||||
corruptShouldBeRepairing: -10
|
||||
corruptShouldBeRestoring: -10
|
||||
corruptShouldBeDestroyed: 2
|
||||
corrupt_should_be_good: -10
|
||||
corrupt_should_be_repairing: -10
|
||||
corrupt_should_be_restoring: -10
|
||||
corrupt_should_be_destroyed: 2
|
||||
corrupt: -10
|
||||
destroyedShouldBeGood: -20
|
||||
destroyedShouldBeRepairing: -20
|
||||
destroyedShouldBeRestoring: -20
|
||||
destroyedShouldBeCorrupt: -20
|
||||
destroyed_should_be_good: -20
|
||||
destroyed_should_be_repairing: -20
|
||||
destroyed_should_be_restoring: -20
|
||||
destroyed_should_be_corrupt: -20
|
||||
destroyed: -20
|
||||
scanning: -2
|
||||
# IER status
|
||||
redIerRunning: -5
|
||||
greenIerBlocked: -10
|
||||
red_ier_running: -5
|
||||
green_ier_blocked: -10
|
||||
|
||||
# Patching / Reset durations
|
||||
osPatchingDuration: 5 # The time taken to patch the OS
|
||||
nodeResetDuration: 5 # The time taken to reset a node (hardware)
|
||||
servicePatchingDuration: 5 # The time taken to patch a service
|
||||
fileSystemRepairingLimit: 5 # The time take to repair the file system
|
||||
fileSystemRestoringLimit: 5 # The time take to restore the file system
|
||||
fileSystemScanningLimit: 5 # The time taken to scan the file system
|
||||
os_patching_duration: 5 # The time taken to patch the OS
|
||||
node_reset_duration: 5 # The time taken to reset a node (hardware)
|
||||
service_patching_duration: 5 # The time taken to patch a service
|
||||
file_system_repairing_limit: 5 # The time take to repair the file system
|
||||
file_system_restoring_limit: 5 # The time take to restore the file system
|
||||
file_system_scanning_limit: 5 # The time taken to scan the file system
|
||||
|
||||
@@ -222,3 +222,42 @@ class TestLinkTrafficLevels:
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
|
||||
@pytest.mark.env_config_paths(
|
||||
dict(
|
||||
training_config_path=TEST_CONFIG_ROOT
|
||||
/ "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml",
|
||||
lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml",
|
||||
)
|
||||
)
|
||||
class TestAccessControlList:
|
||||
"""Test the AccessControlList observation component (in isolation)."""
|
||||
|
||||
def test_obs_shape(self, env: Primaite):
|
||||
"""Try creating env with MultiDiscrete observation space."""
|
||||
env.update_environent_obs()
|
||||
|
||||
# we have two ACLs
|
||||
assert env.env_obs.shape == (5, 2)
|
||||
|
||||
def test_values(self, env: Primaite):
|
||||
"""Test that traffic values are encoded correctly.
|
||||
|
||||
The laydown has:
|
||||
* two services
|
||||
* three nodes
|
||||
* two links
|
||||
* an IER trying to send 999 bits of data over both links the whole time (via the first service)
|
||||
* link bandwidth of 1000, therefore the utilisation is 99.9%
|
||||
"""
|
||||
obs, reward, done, info = env.step(0)
|
||||
obs, reward, done, info = env.step(0)
|
||||
|
||||
# the observation space has combine_service_traffic set to False, so the space has this format:
|
||||
# [link1_service1, link1_service2, link2_service1, link2_service2]
|
||||
# we send 999 bits of data via link1 and link2 on service 1.
|
||||
# therefore the first and third elements should be 6 and all others 0
|
||||
# (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%)
|
||||
print(obs)
|
||||
assert np.array_equal(obs, [6, 0, 6, 0])
|
||||
|
||||
Reference in New Issue
Block a user