diff --git a/src/primaite/config/training_config.py b/src/primaite/config/training_config.py index 9a21d087..14102432 100644 --- a/src/primaite/config/training_config.py +++ b/src/primaite/config/training_config.py @@ -1,4 +1,5 @@ # Crown Copyright (C) Dstl 2022. DEFCON 703. Shared in confidence. +import logging from dataclasses import dataclass, field from pathlib import Path from typing import Any, Dict, Final, Optional, Union @@ -9,6 +10,7 @@ from primaite import USERS_CONFIG_DIR, getLogger from primaite.common.enums import ActionType _LOGGER = getLogger(__name__) +logging.basicConfig(level=logging.DEBUG, format="%(message)s") _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training" diff --git a/src/primaite/environment/observations.py b/src/primaite/environment/observations.py index 96df1f60..fe43c9e3 100644 --- a/src/primaite/environment/observations.py +++ b/src/primaite/environment/observations.py @@ -331,7 +331,6 @@ class AccessControlList(AbstractObservationComponent): """ # Terms (for ACL observation space): - # [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule) # [0, 1] - Permission (0 = DENY, 1 = ALLOW) # [0, num nodes] - Source IP (0 = any, then 1 -> x resolving to IP addresses) # [0, num nodes] - Dest IP (0 = any, then 1 -> x resolving to IP addresses) @@ -352,7 +351,7 @@ class AccessControlList(AbstractObservationComponent): len(env.services_list), len(env.ports_list), ] - shape = acl_shape * self.env.max_acl_rules + shape = acl_shape * self.env.max_number_acl_rules # 2. Create Observation space self.space = spaces.MultiDiscrete(shape) @@ -360,9 +359,6 @@ class AccessControlList(AbstractObservationComponent): # 3. Initialise observation with zeroes self.current_observation = np.zeros(len(shape), dtype=self._DATA_TYPE) - # Dictionary to map services to numbers for obs space - self.services_dict = {} - def update(self): """Update the observation based on current environment state. 
@@ -380,7 +376,6 @@ class AccessControlList(AbstractObservationComponent): permission_int = 0 else: permission_int = 1 - if source_ip == "ANY": source_ip_int = 0 else: @@ -393,9 +388,10 @@ class AccessControlList(AbstractObservationComponent): protocol_int = 0 else: try: - protocol_int = Protocol[protocol] + protocol_int = Protocol[protocol].value except AttributeError: _LOGGER.info(f"Service {protocol} could not be found") + protocol_int = -1 if port == "ANY": port_int = 0 else: @@ -423,7 +419,8 @@ class AccessControlList(AbstractObservationComponent): Resolves IP address -> x (node id e.g. 1 or 2 or 3 or 4) for observation space """ - for key, node in self.env.nodes: + print(type(self.env.nodes)) + for key, node in self.env.nodes.items(): if isinstance(node, ActiveNode) or isinstance(node, ServiceNode): if node.ip_address == ip_address: return key diff --git a/src/primaite/environment/primaite_env.py b/src/primaite/environment/primaite_env.py index 39006259..783b4267 100644 --- a/src/primaite/environment/primaite_env.py +++ b/src/primaite/environment/primaite_env.py @@ -122,6 +122,9 @@ class Primaite(Env): self.training_config.implicit_acl_rule, self.training_config.max_number_acl_rules, ) + # Sets limit for number of ACL rules in environment + self.max_number_acl_rules = self.training_config.max_number_acl_rules + # Create a list of services (enums) self.services_list = [] diff --git a/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml b/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml index 856e963d..7aa30205 100644 --- a/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml +++ b/tests/config/obs_tests/main_config_ACCESS_CONTROL_LIST.yaml @@ -5,14 +5,16 @@ # "STABLE_BASELINES3_PPO" # "STABLE_BASELINES3_A2C" # "GENERIC" -agentIdentifier: NONE +agent_identifier: STABLE_BASELINES3_A2C +# Sets How the Action Space is defined: +# "NODE" +# "ACL" +# "ANY" node and acl actions +action_type: ANY # Number of episodes to run per session 
-observationSpace: - components: - - name: ACCESS_CONTROL_LIST - options: - implicit_acl_rule: DENY - max_number_of_acl_rules: 10 +num_episodes: 1 +# Number of time_steps per episode +num_steps: 5 # Choice whether to have an ALLOW or DENY implicit rule or not (TRUE or FALSE) apply_implicit_rule: True @@ -20,83 +22,86 @@ apply_implicit_rule: True implicit_acl_rule: DENY # Total number of ACL rules allowed in the environment max_number_acl_rules: 10 -numEpisodes: 1 + +observation_space: + components: + - name: ACCESS_CONTROL_LIST + # Time delay between steps (for generic agents) -timeDelay: 1 -# Filename of the scenario / laydown -configFilename: one_node_states_on_off_lay_down_config.yaml +time_delay: 1 + # Type of session to be run (TRAINING or EVALUATION) -sessionType: TRAINING +session_type: TRAINING # Determine whether to load an agent from file -loadAgent: False +load_agent: False # File path and file name of agent if you're loading one in -agentLoadFile: C:\[Path]\[agent_saved_filename.zip] +agent_load_file: C:\[Path]\[agent_saved_filename.zip] # Environment config values # The high value for the observation space -observationSpaceHighValue: 1_000_000_000 +observation_space_high_value: 1_000_000_000 # Reward values # Generic -allOk: 0 +all_ok: 0 # Node Hardware State -offShouldBeOn: -10 -offShouldBeResetting: -5 -onShouldBeOff: -2 -onShouldBeResetting: -5 -resettingShouldBeOn: -5 -resettingShouldBeOff: -2 +off_should_be_on: -10 +off_should_be_resetting: -5 +on_should_be_off: -2 +on_should_be_resetting: -5 +resetting_should_be_on: -5 +resetting_should_be_off: -2 resetting: -3 # Node Software or Service State -goodShouldBePatching: 2 -goodShouldBeCompromised: 5 -goodShouldBeOverwhelmed: 5 -patchingShouldBeGood: -5 -patchingShouldBeCompromised: 2 -patchingShouldBeOverwhelmed: 2 +good_should_be_patching: 2 +good_should_be_compromised: 5 +good_should_be_overwhelmed: 5 +patching_should_be_good: -5 +patching_should_be_compromised: 2 +patching_should_be_overwhelmed: 2 
patching: -3 -compromisedShouldBeGood: -20 -compromisedShouldBePatching: -20 -compromisedShouldBeOverwhelmed: -20 +compromised_should_be_good: -20 +compromised_should_be_patching: -20 +compromised_should_be_overwhelmed: -20 compromised: -20 -overwhelmedShouldBeGood: -20 -overwhelmedShouldBePatching: -20 -overwhelmedShouldBeCompromised: -20 +overwhelmed_should_be_good: -20 +overwhelmed_should_be_patching: -20 +overwhelmed_should_be_compromised: -20 overwhelmed: -20 # Node File System State -goodShouldBeRepairing: 2 -goodShouldBeRestoring: 2 -goodShouldBeCorrupt: 5 -goodShouldBeDestroyed: 10 -repairingShouldBeGood: -5 -repairingShouldBeRestoring: 2 -repairingShouldBeCorrupt: 2 -repairingShouldBeDestroyed: 0 +good_should_be_repairing: 2 +good_should_be_restoring: 2 +good_should_be_corrupt: 5 +good_should_be_destroyed: 10 +repairing_should_be_good: -5 +repairing_should_be_restoring: 2 +repairing_should_be_corrupt: 2 +repairing_should_be_destroyed: 0 repairing: -3 -restoringShouldBeGood: -10 -restoringShouldBeRepairing: -2 -restoringShouldBeCorrupt: 1 -restoringShouldBeDestroyed: 2 +restoring_should_be_good: -10 +restoring_should_be_repairing: -2 +restoring_should_be_corrupt: 1 +restoring_should_be_destroyed: 2 restoring: -6 -corruptShouldBeGood: -10 -corruptShouldBeRepairing: -10 -corruptShouldBeRestoring: -10 -corruptShouldBeDestroyed: 2 +corrupt_should_be_good: -10 +corrupt_should_be_repairing: -10 +corrupt_should_be_restoring: -10 +corrupt_should_be_destroyed: 2 corrupt: -10 -destroyedShouldBeGood: -20 -destroyedShouldBeRepairing: -20 -destroyedShouldBeRestoring: -20 -destroyedShouldBeCorrupt: -20 +destroyed_should_be_good: -20 +destroyed_should_be_repairing: -20 +destroyed_should_be_restoring: -20 +destroyed_should_be_corrupt: -20 destroyed: -20 scanning: -2 # IER status -redIerRunning: -5 -greenIerBlocked: -10 +red_ier_running: -5 +green_ier_blocked: -10 # Patching / Reset durations -osPatchingDuration: 5 # The time taken to patch the OS -nodeResetDuration: 5 # 
The time taken to reset a node (hardware) -servicePatchingDuration: 5 # The time taken to patch a service -fileSystemRepairingLimit: 5 # The time take to repair the file system -fileSystemRestoringLimit: 5 # The time take to restore the file system -fileSystemScanningLimit: 5 # The time taken to scan the file system +os_patching_duration: 5 # The time taken to patch the OS +node_reset_duration: 5 # The time taken to reset a node (hardware) +service_patching_duration: 5 # The time taken to patch a service +file_system_repairing_limit: 5 # The time take to repair the file system +file_system_restoring_limit: 5 # The time take to restore the file system +file_system_scanning_limit: 5 # The time taken to scan the file system diff --git a/tests/test_observation_space.py b/tests/test_observation_space.py index 4e7186b5..4e8df7e1 100644 --- a/tests/test_observation_space.py +++ b/tests/test_observation_space.py @@ -222,3 +222,42 @@ class TestLinkTrafficLevels: # (`7` corresponds to 100% utiilsation and `6` corresponds to 87.5%-100%) print(obs) assert np.array_equal(obs, [6, 0, 6, 0]) + + +@pytest.mark.env_config_paths( + dict( + training_config_path=TEST_CONFIG_ROOT + / "obs_tests/main_config_ACCESS_CONTROL_LIST.yaml", + lay_down_config_path=TEST_CONFIG_ROOT / "obs_tests/laydown.yaml", + ) +) +class TestAccessControlList: + """Test the AccessControlList observation component (in isolation).""" + + def test_obs_shape(self, env: Primaite): + """Try creating env with MultiDiscrete observation space.""" + env.update_environent_obs() + + # we have two ACLs + assert env.env_obs.shape == (5, 2) + + def test_values(self, env: Primaite): + """Test that traffic values are encoded correctly. 
+ + The laydown has: + * two services + * three nodes + * two links + * an IER trying to send 999 bits of data over both links the whole time (via the first service) + * link bandwidth of 1000, therefore the utilisation is 99.9% + """ + obs, reward, done, info = env.step(0) + obs, reward, done, info = env.step(0) + + # the observation space has combine_service_traffic set to False, so the space has this format: + # [link1_service1, link1_service2, link2_service1, link2_service2] + # we send 999 bits of data via link1 and link2 on service 1. + # therefore the first and third elements should be 6 and all others 0 + # (`7` corresponds to 100% utilisation and `6` corresponds to 87.5%-100%) + print(obs) + assert np.array_equal(obs, [6, 0, 6, 0])