diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 90bca35f..9b9434af 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -43,6 +43,26 @@ class FileObservation(AbstractObservation, identifier="FILE"): if self.include_num_access: self.default_observation["num_access"] = 0 + # TODO: allow these to be configured in yaml + self.high_threshold = 10 + self.med_threshold = 5 + self.low_threshold = 0 + + def _categorise_num_access(self, num_access: int) -> int: + """ + Represent number of file accesses as a categorical variable. + + :param num_access: Number of file accesses. + :return: Bin number corresponding to the number of accesses. + """ + if num_access > self.high_threshold: + return 3 + elif num_access > self.med_threshold: + return 2 + elif num_access > self.low_threshold: + return 1 + return 0 + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -57,8 +77,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): return self.default_observation obs = {"health_status": file_state["visible_status"]} if self.include_num_access: - obs["num_access"] = file_state["num_access"] - # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") + obs["num_access"] = self._categorise_num_access(file_state["num_access"]) return obs @property diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index ab48e606..0c10a8d2 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -214,9 +214,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Constructed firewall observation instance. :rtype: FirewallObservation """ - where = parent_where + ["nodes", config.hostname] return cls( - where=where, + where=parent_where + ["nodes", config.hostname], ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 3703fa1c..1d428fa8 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -185,9 +185,11 @@ class ObservationManager: Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. - It should contain the key 'type' which selects which observation class to use (from a choice of: - UC2BlueObservation, UC2RedObservation, UC2GreenObservation) - The other key is 'options' which are passed to the constructor of the selected observation class. + If None, a blank observation space is created. + Otherwise, this must be a Dict with a type field and options field. + type: string that corresponds to one of the observation identifiers that are provided when subclassing + AbstractObservation + options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict :param game: Reference to the PrimaiteGame object that spawned this observation. :type game: PrimaiteGame diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 40788760..2c4806d9 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -98,6 +98,26 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): self.where = where self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} + # TODO: allow these to be configured in yaml + self.high_threshold = 10 + self.med_threshold = 5 + self.low_threshold = 0 + + def _categorise_num_executions(self, num_executions: int) -> int: + """ + Represent number of file accesses as a categorical variable. + + :param num_access: Number of file accesses. + :return: Bin number corresponding to the number of accesses. + """ + if num_executions > self.high_threshold: + return 3 + elif num_executions > self.med_threshold: + return 2 + elif num_executions > self.low_threshold: + return 1 + return 0 + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -113,7 +133,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): return { "operating_status": application_state["operating_state"], "health_status": application_state["health_state_visible"], - "num_executions": application_state["num_executions"], + "num_executions": self._categorise_num_executions(application_state["num_executions"]), } @property diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index b7848c53..203ea3ea 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -64,25 +64,67 @@ agents: - ref: defender team: BLUE type: ProxyAgent + observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: client_1 - links: - - link_ref: client_1___switch_1 - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: client_1 - nic_num: 1 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: client_1 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.0.10 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - client_1___switch_1 + - type: "NONE" + label: ICS + options: {} + + # observation_space: + # type: UC2BlueObservation + # options: + # num_services_per_node: 1 + # num_folders_per_node: 1 + # num_files_per_folder: 1 + # num_nics_per_node: 2 + # nodes: + # - node_hostname: client_1 + # links: + # - link_ref: client_1___switch_1 + # acl: + # options: + # max_acl_rules: 10 + # router_hostname: router_1 + # ip_address_order: + # - node_hostname: client_1 + # nic_num: 1 + # ics: null action_space: action_list: - type: DONOTHING diff --git a/tests/assets/configs/test_application_install.yaml b/tests/assets/configs/test_application_install.yaml index b3fca4bc..ccd2228c 100644 --- a/tests/assets/configs/test_application_install.yaml +++ b/tests/assets/configs/test_application_install.yaml @@ -41,8 +41,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -91,8 +90,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -141,10 +139,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -177,61 +172,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: DNSServer - - node_hostname: web_server - services: - - service_name: WebServer - - node_hostname: database_server - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index af5e9650..cb83ac5e 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -72,3 +72,6 @@ def test_folder_observation(simulation): observation_state = root_folder_obs.observe(simulation.describe_state()) assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too + + +# TODO: Add tests to check num access is correct. diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index 1a41cad4..3eee72e8 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -51,31 +51,8 @@ def simulation() -> Simulation: return sim -def test_link_observation(simulation): - """Test the link observation.""" - # get a link - link: Link = next(iter(simulation.network.links.values())) - - computer: Computer = simulation.network.get_node_by_hostname("computer") - server: Server = simulation.network.get_node_by_hostname("server") - - simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep - - link_obs = LinkObservation(where=["network", "links", link.uuid]) - - assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10 - - observation_state = link_obs.observe(simulation.describe_state()) - assert observation_state.get("PROTOCOLS") is not None - assert observation_state["PROTOCOLS"]["ALL"] == 0 - - computer.ping(server.network_interface.get(1).ip_address) - - observation_state = link_obs.observe(simulation.describe_state()) - assert observation_state["PROTOCOLS"]["ALL"] == 1 - - -def test_link_observation_again(): +def test_link_observation(): + """Check the shape and contents of the link observation.""" net = Network() sim = Simulation(network=net) switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) @@ -102,6 +79,8 @@ def test_link_observation_again(): assert "PROTOCOLS" in link_2_obs assert "ALL" in link_1_obs["PROTOCOLS"] assert "ALL" in link_2_obs["PROTOCOLS"] + assert link_1_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) + assert link_2_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) assert link_1_obs["PROTOCOLS"]["ALL"] == 0 assert link_2_obs["PROTOCOLS"]["ALL"] == 0 diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index 0a34ab67..ed07e030 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -19,3 +19,8 @@ def test_file_observation(): ) assert dog_file_obs.observe(state) == {"health_status": 1} assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) + + +# TODO: +# def test_file_num_access(): +# ...