From fbb4eba6b74cbd766fc4c9ecb1bb19cc866c896e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 27 Mar 2024 03:27:44 +0000 Subject: [PATCH 01/16] Draft new observation space config --- .../_package_data/data_manipulation.yaml | 49 +++++++++++++++++-- 1 file changed, 45 insertions(+), 4 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 12f60b63..06028ee1 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -176,13 +176,54 @@ agents: team: BLUE type: ProxyAgent + observation_space: + - type: NODES + label: NODES # What is the dictionary key called + options: + hosts: + - hostname: domain_controller + - hostname: web_server + - hostname: database_server + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + routers: + - hostname: router_1 + firewalls: {} + + num_host_services: 1 + num_host_applications: 0 + num_host_folders: 1 + num_host_files: 1 + num_host_network_interfaces: 2 + num_router_ports: 4 + num_acl_rules: 10 + num_firewall_ports: 4 + firewalls_internal_inbound_acl: true + firewalls_internal_outbound_acl: true + firewalls_dmz_inbound_acl: true + firewalls_dmz_outbound_acl: true + firewalls_external_inbound_acl: true + firewalls_external_outbound_acl: true + - type: LINKS + label: "LINKS" + options: + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + observation_space: type: UC2BlueObservation options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 
2 nodes: - node_hostname: domain_controller services: From cae9f64b93d62d1798fda00c6a20de6b19f25eb7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 27 Mar 2024 22:11:02 +0000 Subject: [PATCH 02/16] New observations --- .../agent/observations/agent_observations.py | 52 +- .../agent/observations/node_observations.py | 466 ++++++++++++++++- .../agent/observations/observation_manager.py | 20 +- .../game/agent/observations/observations.py | 490 +++++++++--------- src/primaite/game/agent/utils.py | 6 +- .../observations/test_node_observations.py | 2 +- 6 files changed, 727 insertions(+), 309 deletions(-) diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py index 70a83881..2148697b 100644 --- a/src/primaite/game/agent/observations/agent_observations.py +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces -from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.game.agent.observations.host import NodeObservation from primaite.game.agent.observations.observations import ( AbstractObservation, AclObservation, @@ -136,53 +136,3 @@ class UC2BlueObservation(AbstractObservation): new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) return new - -class UC2RedObservation(AbstractObservation): - """Container for all observations used by the red agent in UC2.""" - - def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: - super().__init__() - self.where: Optional[List[str]] = where - self.nodes: List[NodeObservation] = nodes - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - if self.where is None: - 
return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": - """ - Create UC2 red observation from a config. - - :param config: Dictionary containing the configuration for this UC2 red observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] - return cls(nodes=nodes, where=["network"]) - - -class UC2GreenObservation(NullObservation): - """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" - - pass diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 94f0974b..42bdb749 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,21 +1,473 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from __future__ import annotations +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union from gymnasium import spaces +from gymnasium.core import ObsType +from pydantic import BaseModel, ConfigDict from primaite import getLogger -from primaite.game.agent.observations.file_system_observations import FolderObservation -from primaite.game.agent.observations.nic_observations import NicObservation from primaite.game.agent.observations.observations import AbstractObservation -from 
primaite.game.agent.observations.software_observation import ServiceObservation +# from primaite.game.agent.observations.file_system_observations import FolderObservation +# from primaite.game.agent.observations.nic_observations import NicObservation +# from primaite.game.agent.observations.software_observation import ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +WhereType = Iterable[str | int] | None -class NodeObservation(AbstractObservation): +class ServiceObservation(AbstractObservation, identifier="SERVICE"): + class ConfigSchema(AbstractObservation.ConfigSchema): + service_name: str + + def __init__(self, where: WhereType)->None: + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0} + + def observe(self, state: Dict) -> Any: + service_state = access_from_nested_dict(state, self.where) + if service_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": service_state["operating_state"], + "health_status": service_state["health_state_visible"], + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + return cls(where=parent_where+["services", config.service_name]) + + +class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): + class ConfigSchema(AbstractObservation.ConfigSchema): + application_name: str + + def __init__(self, where: WhereType)->None: + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} + + def observe(self, state: Dict) -> Any: + # raise 
NotImplementedError("TODO NUM EXECUTIONS NEEDS TO BE CONVERTED TO A CATEGORICAL") + application_state = access_from_nested_dict(state, self.where) + if application_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": application_state["operating_state"], + "health_status": application_state["health_state_visible"], + "num_executions": application_state["num_executions"], + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict({ + "operating_status": spaces.Discrete(7), + "health_status": spaces.Discrete(5), + "num_executions": spaces.Discrete(4) + }) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ApplicationObservation: + return cls(where=parent_where+["applications", config.application_name]) + + +class FileObservation(AbstractObservation, identifier="FILE"): + class ConfigSchema(AbstractObservation.ConfigSchema): + file_name: str + include_num_access : bool = False + + def __init__(self, where: WhereType, include_num_access: bool)->None: + self.where: WhereType = where + self.include_num_access :bool = include_num_access + + self.default_observation: ObsType = {"health_status": 0} + if self.include_num_access: + self.default_observation["num_access"] = 0 + + def observe(self, state: Dict) -> Any: + file_state = access_from_nested_dict(state, self.where) + if file_state is NOT_PRESENT_IN_STATE: + return self.default_observation + obs = {"health_status": file_state["visible_status"]} + if self.include_num_access: + obs["num_access"] = file_state["num_access"] + # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") + return obs + + @property + def space(self) -> spaces.Space: + space = {"health_status": spaces.Discrete(6)} + if self.include_num_access: + space["num_access"] = spaces.Discrete(4) + return spaces.Dict(space) + + @classmethod + 
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation: + return cls(where=parent_where+["files", config.file_name], include_num_access=config.include_num_access) + + +class FolderObservation(AbstractObservation, identifier="FOLDER"): + class ConfigSchema(AbstractObservation.ConfigSchema): + folder_name: str + files: List[FileObservation.ConfigSchema] = [] + num_files : int = 0 + include_num_access : bool = False + + def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None: + self.where: WhereType = where + + self.files: List[FileObservation] = files + while len(self.files) < num_files: + self.files.append(FileObservation(where=None,include_num_access=include_num_access)) + while len(self.files) > num_files: + truncated_file = self.files.pop() + msg = f"Too many files in folder observation. Truncating file {truncated_file}" + _LOGGER.warning(msg) + + self.default_observation = { + "health_status": 0, + "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, + } + + def observe(self, state: Dict) -> Any: + folder_state = access_from_nested_dict(state, self.where) + if folder_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + health_status = folder_state["health_status"] + + obs = {} + + obs["health_status"] = health_status + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. 
+ + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "health_status": spaces.Discrete(6), + "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), + } + ) + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation: + where = parent_where + ["folders", config.folder_name] + + #pass down shared/common config items + for file_config in config.files: + file_config.include_num_access = config.include_num_access + + files = [FileObservation.from_config(config=f, parent_where = where) for f in config.files] + return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) + + +class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): + class ConfigSchema(AbstractObservation.ConfigSchema): + nic_num: int + include_nmne: bool = False + + + def __init__(self, where: WhereType, include_nmne: bool)->None: + self.where = where + self.include_nmne : bool = include_nmne + + self.default_observation: ObsType = {"nic_status": 0} + if self.include_nmne: + self.default_observation.update({"NMNE":{"inbound":0, "outbound":0}}) + + def observe(self, state: Dict) -> Any: + # raise NotImplementedError("TODO: CATEGORISATION") + nic_state = access_from_nested_dict(state, self.where) + + if nic_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {"nic_status": 1 if nic_state["enabled"] else 2} + if self.include_nmne: + obs.update({"NMNE": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs["NMNE"]["outbound"] = 
self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + self.nmne_inbound_last_step = inbound_count + self.nmne_outbound_last_step = outbound_count + return obs + + + @property + def space(self) -> spaces.Space: + space = spaces.Dict({"nic_status": spaces.Discrete(3)}) + + if self.include_nmne: + space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + + return space + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne) + + +class HostObservation(AbstractObservation, identifier="HOST"): + class ConfigSchema(AbstractObservation.ConfigSchema): + hostname: str + services: List[ServiceObservation.ConfigSchema] = [] + applications: List[ApplicationObservation.ConfigSchema] = [] + folders: List[FolderObservation.ConfigSchema] = [] + network_interfaces: List[NICObservation.ConfigSchema] = [] + num_services: int + num_applications: int + num_folders: int + num_files: int + num_nics: int + include_nmne: bool + include_num_access: bool + + def __init__(self, + where: WhereType, + services:List[ServiceObservation], + applications:List[ApplicationObservation], + folders:List[FolderObservation], + network_interfaces:List[NICObservation], + num_services: int, + num_applications: int, + num_folders: int, + num_files: int, + num_nics: int, + include_nmne: bool, + include_num_access: bool + )->None: + + self.where : WhereType = where + + # ensure service list has length equal to num_services by truncating or padding + self.services: List[ServiceObservation] = services + while len(self.services) < num_services: + self.services.append(ServiceObservation(where=None)) + while len(self.services) > num_services: + truncated_service = self.services.pop() + msg = f"Too many services in Node observation space for node. 
Truncating service {truncated_service.where}" + _LOGGER.warning(msg) + + # ensure application list has length equal to num_applications by truncating or padding + self.applications: List[ApplicationObservation] = applications + while len(self.applications) < num_applications: + self.applications.append(ApplicationObservation(where=None)) + while len(self.applications) > num_applications: + truncated_application = self.applications.pop() + msg = f"Too many applications in Node observation space for node. Truncating application {truncated_application.where}" + _LOGGER.warning(msg) + + # ensure folder list has length equal to num_folders by truncating or padding + self.folders: List[FolderObservation] = folders + while len(self.folders) < num_folders: + self.folders.append(FolderObservation(where = None, files= [], num_files=num_files, include_num_access=include_num_access)) + while len(self.folders) > num_folders: + truncated_folder = self.folders.pop() + msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}" + _LOGGER.warning(msg) + + # ensure network_interface list has length equal to num_network_interfaces by truncating or padding + self.network_interfaces: List[NICObservation] = network_interfaces + while len(self.network_interfaces) < num_nics: + self.network_interfaces.append(NICObservation(where = None, include_nmne=include_nmne)) + while len(self.network_interfaces) > num_nics: + truncated_nic = self.network_interfaces.pop() + msg = f"Too many network_interfaces in Node observation space for node. 
Truncating {truncated_folder.where}" + _LOGGER.warning(msg) + + self.default_observation: ObsType = { + "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, + "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "operating_status": 0, + "num_file_creations": 0, + "num_file_deletions": 0, + } + + + def observe(self, state: Dict) -> Any: + node_state = access_from_nested_dict(state, self.where) + if node_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + obs["operating_status"] = node_state["operating_state"] + obs["NICS"] = { + i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) + } + obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] + obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + return obs + + @property + def space(self) -> spaces.Space: + shape = { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(5), + "NICS": spaces.Dict( + {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} + ), + "num_file_creations" : spaces.Discrete(4), + "num_file_deletions" : spaces.Discrete(4), + } + return spaces.Dict(shape) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation: + if parent_where is None: + where = ["network", "nodes", config.hostname] + else: + where = parent_where + ["nodes", config.hostname] + + #pass down shared/common config items 
+ for folder_config in config.folders: + folder_config.include_num_access = config.include_num_access + folder_config.num_files = config.num_files + for nic_config in config.network_interfaces: + nic_config.include_nmne = config.include_nmne + + services = [ServiceObservation.from_config(config=c,parent_where=where) for c in config.services] + applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] + folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] + nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] + + return cls( + where = where, + services = services, + applications = applications, + folders = folders, + network_interfaces = nics, + num_services = config.num_services, + num_applications = config.num_applications, + num_folders = config.num_folders, + num_files = config.num_files, + num_nics = config.num_nics, + include_nmne = config.include_nmne, + include_num_access = config.include_num_access, + ) + + +class PortObservation(AbstractObservation, identifier="PORT"): + class ConfigSchema(AbstractObservation.ConfigSchema): + pass + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class ACLObservation(AbstractObservation, identifier="ACL"): + class ConfigSchema(AbstractObservation.ConfigSchema): + pass + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class RouterObservation(AbstractObservation, identifier="ROUTER"): + class 
ConfigSchema(AbstractObservation.ConfigSchema): + hostname: str + ports: List[PortObservation.ConfigSchema] + + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class FirewallObservation(AbstractObservation, identifier="FIREWALL"): + class ConfigSchema(AbstractObservation.ConfigSchema): + hostname: str + ports: List[PortObservation.ConfigSchema] = [] + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class NodesObservation(AbstractObservation, identifier="NODES"): + class ConfigSchema(AbstractObservation.ConfigSchema): + """Config""" + hosts: List[HostObservation.ConfigSchema] = [] + routers: List[RouterObservation.ConfigSchema] = [] + firewalls: List[FirewallObservation.ConfigSchema] = [] + num_services: int = 1 + + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +############################ OLD + +class NodeObservation(AbstractObservation, identifier= "OLD"): """Observation of a node in the network. 
Includes services, folders and NICs.""" def __init__( diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 400345fa..be90041e 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -2,11 +2,6 @@ from typing import Dict, TYPE_CHECKING from gymnasium.core import ObsType -from primaite.game.agent.observations.agent_observations import ( - UC2BlueObservation, - UC2GreenObservation, - UC2RedObservation, -) from primaite.game.agent.observations.observations import AbstractObservation if TYPE_CHECKING: @@ -63,11 +58,10 @@ class ObservationManager: :param game: Reference to the PrimaiteGame object that spawned this observation. :type game: PrimaiteGame """ - if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) - else: - raise ValueError("Observation space type invalid") + + for obs_cfg in config: + obs_type = obs_cfg['type'] + obs_class = AbstractObservation._registry[obs_type] + observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options'])) + obs_manager = cls(observation) + return obs_manager \ No newline at end of file diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 6236b00d..dc41e8e5 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,9 +1,10 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod from ipaddress import IPv4Address -from typing import Any, Dict, 
List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type from gymnasium import spaces +from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -17,6 +18,28 @@ if TYPE_CHECKING: class AbstractObservation(ABC): """Abstract class for an observation space component.""" + class ConfigSchema(ABC, BaseModel): + model_config = ConfigDict(extra="forbid") + + _registry: Dict[str, Type["AbstractObservation"]] = {} + """Registry of observation components, with their name as key. + + Automatically populated when subclasses are defined. Used for defining from_config. + """ + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + """ + Register an observation type. + + :param identifier: Identifier used to uniquely specify observation component types. + :type identifier: str + :raises ValueError: When attempting to create a component with a name that is already in use. + """ + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate observation component type {identifier}") + cls._registry[identifier] = cls + @abstractmethod def observe(self, state: Dict) -> Any: """ @@ -36,274 +59,271 @@ class AbstractObservation(ABC): pass @classmethod - @abstractmethod - def from_config(cls, config: Dict, game: "PrimaiteGame"): - """Create this observation space component form a serialised format. - - The `game` parameter is for a the PrimaiteGame object that spawns this component. 
- """ - pass + def from_config(cls, cfg: Dict) -> "AbstractObservation": + """Create this observation space component form a serialised format.""" + ObservationType = cls._registry[cfg['type']] + return ObservationType.from_config(cfg=cfg) -class LinkObservation(AbstractObservation): - """Observation of a link in the network.""" +# class LinkObservation(AbstractObservation): +# """Observation of a link in the network.""" - default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} - "Default observation is what should be returned when the link doesn't exist." +# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} +# "Default observation is what should be returned when the link doesn't exist." - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise link observation. +# def __init__(self, where: Optional[Tuple[str]] = None) -> None: +# """Initialise link observation. - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. +# :param where: Store information about where in the simulation state dictionary to find the relevant information. +# Optional. If None, this corresponds that the file does not exist and the observation will be populated with +# zeroes. - A typical location for a service looks like this: - `['network','nodes',,'servics', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where +# A typical location for a service looks like this: +# `['network','nodes',,'servics', ]` +# :type where: Optional[List[str]] +# """ +# super().__init__() +# self.where: Optional[Tuple[str]] = where - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. 
+# def observe(self, state: Dict) -> Dict: +# """Generate observation based on the current state of the simulation. - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation +# :param state: Simulation state dictionary +# :type state: Dict +# :return: Observation +# :rtype: Dict +# """ +# if self.where is None: +# return self.default_observation - link_state = access_from_nested_dict(state, self.where) - if link_state is NOT_PRESENT_IN_STATE: - return self.default_observation +# link_state = access_from_nested_dict(state, self.where) +# if link_state is NOT_PRESENT_IN_STATE: +# return self.default_observation - bandwidth = link_state["bandwidth"] - load = link_state["current_load"] - if load == 0: - utilisation_category = 0 - else: - utilisation_fraction = load / bandwidth - # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% - utilisation_category = int(utilisation_fraction * 9) + 1 +# bandwidth = link_state["bandwidth"] +# load = link_state["current_load"] +# if load == 0: +# utilisation_category = 0 +# else: +# utilisation_fraction = load / bandwidth +# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% +# utilisation_category = int(utilisation_fraction * 9) + 1 - # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} +# # TODO: once the links support separte load per protocol, this needs amendment to reflect that. +# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. +# @property +# def space(self) -> spaces.Space: +# """Gymnasium space object describing the observation space shape. 
- :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) +# :return: Gymnasium space +# :rtype: spaces.Space +# """ +# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": - """Create link observation from a config. +# @classmethod +# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": +# """Create link observation from a config. - :param config: Dictionary containing the configuration for this link observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed link observation - :rtype: LinkObservation - """ - return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) +# :param config: Dictionary containing the configuration for this link observation. +# :type config: Dict +# :param game: Reference to the PrimaiteGame object that spawned this observation. +# :type game: PrimaiteGame +# :return: Constructed link observation +# :rtype: LinkObservation +# """ +# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -class AclObservation(AbstractObservation): - """Observation of an Access Control List (ACL) in the network.""" +# class AclObservation(AbstractObservation): +# """Observation of an Access Control List (ACL) in the network.""" - # TODO: should where be optional, and we can use where=None to pad the observation space? - # definitely the current approach does not support tracking files that aren't specified by name, for example - # if a file is created at runtime, we have currently got no way of telling the observation space to track it. - # this needs adding, but not for the MVP. 
- def __init__( - self, - node_ip_to_id: Dict[str, int], - ports: List[int], - protocols: List[str], - where: Optional[Tuple[str]] = None, - num_rules: int = 10, - ) -> None: - """Initialise ACL observation. +# # TODO: should where be optional, and we can use where=None to pad the observation space? +# # definitely the current approach does not support tracking files that aren't specified by name, for example +# # if a file is created at runtime, we have currently got no way of telling the observation space to track it. +# # this needs adding, but not for the MVP. +# def __init__( +# self, +# node_ip_to_id: Dict[str, int], +# ports: List[int], +# protocols: List[str], +# where: Optional[Tuple[str]] = None, +# num_rules: int = 10, +# ) -> None: +# """Initialise ACL observation. - :param node_ip_to_id: Mapping between IP address and ID. - :type node_ip_to_id: Dict[str, int] - :param ports: List of ports which are part of the game that define the ordering when converting to an ID - :type ports: List[int] - :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID - :type protocols: list[str] - :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
A typical - example may look like this: - ['network','nodes',,'acl','acl'] - :type where: Optional[Tuple[str]], optional - :param num_rules: , defaults to 10 - :type num_rules: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.num_rules: int = num_rules - self.node_to_id: Dict[str, int] = node_ip_to_id - "List of node IP addresses, order in this list determines how they are converted to an ID" - self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} - "List of ports which are part of the game that define the ordering when converting to an ID" - self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} - "List of protocols which are part of the game, defines ordering when converting to an ID" - self.default_observation: Dict = { - i - + 1: { - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - for i in range(self.num_rules) - } +# :param node_ip_to_id: Mapping between IP address and ID. +# :type node_ip_to_id: Dict[str, int] +# :param ports: List of ports which are part of the game that define the ordering when converting to an ID +# :type ports: List[int] +# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID +# :type protocols: list[str] +# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
A typical +# example may look like this: +# ['network','nodes',,'acl','acl'] +# :type where: Optional[Tuple[str]], optional +# :param num_rules: , defaults to 10 +# :type num_rules: int, optional +# """ +# super().__init__() +# self.where: Optional[Tuple[str]] = where +# self.num_rules: int = num_rules +# self.node_to_id: Dict[str, int] = node_ip_to_id +# "List of node IP addresses, order in this list determines how they are converted to an ID" +# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} +# "List of ports which are part of the game that define the ordering when converting to an ID" +# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} +# "List of protocols which are part of the game, defines ordering when converting to an ID" +# self.default_observation: Dict = { +# i +# + 1: { +# "position": i, +# "permission": 0, +# "source_node_id": 0, +# "source_port": 0, +# "dest_node_id": 0, +# "dest_port": 0, +# "protocol": 0, +# } +# for i in range(self.num_rules) +# } - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. +# def observe(self, state: Dict) -> Dict: +# """Generate observation based on the current state of the simulation. 
- :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - acl_state: Dict = access_from_nested_dict(state, self.where) - if acl_state is NOT_PRESENT_IN_STATE: - return self.default_observation +# :param state: Simulation state dictionary +# :type state: Dict +# :return: Observation +# :rtype: Dict +# """ +# if self.where is None: +# return self.default_observation +# acl_state: Dict = access_from_nested_dict(state, self.where) +# if acl_state is NOT_PRESENT_IN_STATE: +# return self.default_observation - # TODO: what if the ACL has more rules than num of max rules for obs space - obs = {} - acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. - while i < self.num_rules + 1: - rule_state = acl_items[i] - if rule_state is None: - obs[i] = { - "position": i - 1, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - else: - src_ip = rule_state["src_ip_address"] - src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] - dst_ip = rule_state["dst_ip_address"] - dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] - src_port = rule_state["src_port"] - src_port_id = 1 if src_port is None else self.port_to_id[src_port] - dst_port = rule_state["dst_port"] - dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] - protocol = rule_state["protocol"] - protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] - obs[i] = { - "position": i - 1, - "permission": rule_state["action"], - "source_node_id": src_node_id, - "source_port": src_port_id, - "dest_node_id": dst_node_ip, - "dest_port": dst_port_id, - "protocol": protocol_id, - } - i += 1 - return obs +# # TODO: what if the ACL has more rules than num of max rules for obs space +# obs = {} +# acl_items = dict(acl_state.items()) +# i = 1 # don't 
show rule 0 for compatibility reasons. +# while i < self.num_rules + 1: +# rule_state = acl_items[i] +# if rule_state is None: +# obs[i] = { +# "position": i - 1, +# "permission": 0, +# "source_node_id": 0, +# "source_port": 0, +# "dest_node_id": 0, +# "dest_port": 0, +# "protocol": 0, +# } +# else: +# src_ip = rule_state["src_ip_address"] +# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] +# dst_ip = rule_state["dst_ip_address"] +# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] +# src_port = rule_state["src_port"] +# src_port_id = 1 if src_port is None else self.port_to_id[src_port] +# dst_port = rule_state["dst_port"] +# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] +# protocol = rule_state["protocol"] +# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] +# obs[i] = { +# "position": i - 1, +# "permission": rule_state["action"], +# "source_node_id": src_node_id, +# "source_port": src_port_id, +# "dest_node_id": dst_node_ip, +# "dest_port": dst_port_id, +# "protocol": protocol_id, +# } +# i += 1 +# return obs - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. +# @property +# def space(self) -> spaces.Space: +# """Gymnasium space object describing the observation space shape. 
- :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) - } - ) +# :return: Gymnasium space +# :rtype: spaces.Space +# """ +# return spaces.Dict( +# { +# i +# + 1: spaces.Dict( +# { +# "position": spaces.Discrete(self.num_rules), +# "permission": spaces.Discrete(3), +# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) +# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), +# "source_port": spaces.Discrete(len(self.port_to_id) + 2), +# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), +# "dest_port": spaces.Discrete(len(self.port_to_id) + 2), +# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), +# } +# ) +# for i in range(self.num_rules) +# } +# ) - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": - """Generate ACL observation from a config. +# @classmethod +# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": +# """Generate ACL observation from a config. - :param config: Dictionary containing the configuration for this ACL observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. 
- :type game: PrimaiteGame - :return: Observation object - :rtype: AclObservation - """ - max_acl_rules = config["options"]["max_acl_rules"] - node_ip_to_idx = {} - for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): - node_ref = ip_map_config["node_hostname"] - nic_num = ip_map_config["nic_num"] - node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] - nic_obj = node_obj.network_interface[nic_num] - node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 +# :param config: Dictionary containing the configuration for this ACL observation. +# :type config: Dict +# :param game: Reference to the PrimaiteGame object that spawned this observation. +# :type game: PrimaiteGame +# :return: Observation object +# :rtype: AclObservation +# """ +# max_acl_rules = config["options"]["max_acl_rules"] +# node_ip_to_idx = {} +# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): +# node_ref = ip_map_config["node_hostname"] +# nic_num = ip_map_config["nic_num"] +# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] +# nic_obj = node_obj.network_interface[nic_num] +# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - router_hostname = config["router_hostname"] - return cls( - node_ip_to_id=node_ip_to_idx, - ports=game.options.ports, - protocols=game.options.protocols, - where=["network", "nodes", router_hostname, "acl", "acl"], - num_rules=max_acl_rules, - ) +# router_hostname = config["router_hostname"] +# return cls( +# node_ip_to_id=node_ip_to_idx, +# ports=game.options.ports, +# protocols=game.options.protocols, +# where=["network", "nodes", router_hostname, "acl", "acl"], +# num_rules=max_acl_rules, +# ) -class NullObservation(AbstractObservation): - """Null observation, returns a single 0 value for the observation space.""" +# class NullObservation(AbstractObservation): +# """Null observation, returns a single 0 value for the observation space.""" - def __init__(self, where: Optional[List[str]] = None): - """Initialise 
null observation.""" - self.default_observation: Dict = {} +# def __init__(self, where: Optional[List[str]] = None): +# """Initialise null observation.""" +# self.default_observation: Dict = {} - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - return 0 +# def observe(self, state: Dict) -> Dict: +# """Generate observation based on the current state of the simulation.""" +# return 0 - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Discrete(1) +# @property +# def space(self) -> spaces.Space: +# """Gymnasium space object describing the observation space shape.""" +# return spaces.Discrete(1) - @classmethod - def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": - """ - Create null observation from a config. +# @classmethod +# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": +# """ +# Create null observation from a config. - The parameters are ignored, they are here to match the signature of the other observation classes. - """ - return cls() +# The parameters are ignored, they are here to match the signature of the other observation classes. 
+# """ +# return cls() -class ICSObservation(NullObservation): - """ICS observation placeholder, currently not implemented so always returns a single 0.""" +# class ICSObservation(NullObservation): +# """ICS observation placeholder, currently not implemented so always returns a single 0.""" - pass +# pass diff --git a/src/primaite/game/agent/utils.py b/src/primaite/game/agent/utils.py index 1314087c..42e8f30b 100644 --- a/src/primaite/game/agent/utils.py +++ b/src/primaite/game/agent/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Hashable, Sequence +from typing import Any, Dict, Hashable, Optional, Sequence NOT_PRESENT_IN_STATE = object() """ @@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is """ -def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: +def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any: """ Access an item from a deeply dictionary with a list of keys. @@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: :return: The value in the dictionary :rtype: Any """ + if keys is None: + return NOT_PRESENT_IN_STATE key_list = [*keys] # copy keys to a new list to prevent editing original list if len(key_list) == 0: return dictionary diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index dce05b6a..a5195e1e 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -4,7 +4,7 @@ from uuid import uuid4 import pytest from gymnasium import spaces -from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.game.agent.observations.host import NodeObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from 
primaite.simulator.sim_container import Simulation From 0d0b5bc7d9fc64549f19017ded0569bfe9ba45ce Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Wed, 27 Mar 2024 22:11:37 +0000 Subject: [PATCH 03/16] fix previous commit --- src/primaite/game/agent/observations/agent_observations.py | 2 +- .../game_layer/observations/test_node_observations.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py index 2148697b..10370660 100644 --- a/src/primaite/game/agent/observations/agent_observations.py +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces -from primaite.game.agent.observations.host import NodeObservation +from primaite.game.agent.observations.node_observations import NodeObservation from primaite.game.agent.observations.observations import ( AbstractObservation, AclObservation, diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index a5195e1e..dce05b6a 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -4,7 +4,7 @@ from uuid import uuid4 import pytest from gymnasium import spaces -from primaite.game.agent.observations.host import NodeObservation +from primaite.game.agent.observations.node_observations import NodeObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation From f88b4c0f97716ff03344ae22d732252733749c58 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 28 Mar 2024 17:40:27 +0000 Subject: [PATCH 04/16] #2417 more observations --- .../agent/observations/node_observations.py | 539 
+++++++++++------- 1 file changed, 322 insertions(+), 217 deletions(-) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 42bdb749..5d46b743 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,4 +1,6 @@ +# TODO: make sure when config options are being passed down from higher-level observations to lower-level, but the lower-level also defines that option, don't overwrite. from __future__ import annotations +from ipaddress import IPv4Address from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union from gymnasium import spaces @@ -163,7 +165,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): } ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FolderObservation: where = parent_where + ["folders", config.folder_name] #pass down shared/common config items @@ -220,7 +222,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> NICObservation: return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne) @@ -333,7 +335,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): return spaces.Dict(shape) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> HostObservation: if parent_where is None: where = ["network", "nodes", config.hostname] else: @@ -369,78 +371,282 @@ class 
HostObservation(AbstractObservation, identifier="HOST"): class PortObservation(AbstractObservation, identifier="PORT"): class ConfigSchema(AbstractObservation.ConfigSchema): - pass + port_id : int def __init__(self, where: WhereType)->None: - pass + self.where = where + self.default_observation: ObsType = {"operating_status" : 0} def observe(self, state: Dict) -> Any: - pass + port_state = access_from_nested_dict(state, self.where) + if port_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"operating_status": 1 if port_state["enabled"] else 2 } @property def space(self) -> spaces.Space: - pass + return spaces.Dict({"operating_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> PortObservation: + return cls(where = parent_where + ["NICs", config.port_id]) class ACLObservation(AbstractObservation, identifier="ACL"): class ConfigSchema(AbstractObservation.ConfigSchema): - pass + ip_list: List[IPv4Address] + port_list: List[int] + protocol_list: List[str] + num_rules: int - def __init__(self, where: WhereType)->None: - pass + def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None: + self.where = where + self.num_rules: int = num_rules + self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)} + self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)} + self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)} + self.default_observation: Dict = { + i + + 1: { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + for i in range(self.num_rules) + } def observe(self, state: Dict) -> Any: - pass + acl_state: Dict = access_from_nested_dict(state, self.where) + 
if acl_state is NOT_PRESENT_IN_STATE: + return self.default_observation + obs = {} + acl_items = dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] + if rule_state is None: + obs[i] = { + "position": i - 1, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + else: + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, + "permission": rule_state["action"], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, + } + i += 1 + return obs @property def space(self) -> spaces.Space: - pass + raise NotImplementedError("TODO: need to add wildcard id.") + return spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "source_port": spaces.Discrete(len(self.port_to_id) + 2), + "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), + } + ) + for i in range(self.num_rules) + } + ) + @classmethod - def from_config(cls, config: ConfigSchema, 
parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ACLObservation: + return cls( + where = parent_where+["acl", "acl"], + num_rules = config.num_rules, + ip_list = config.ip_list, + ports = config.port_list, + protocols = config.protocol_list + ) class RouterObservation(AbstractObservation, identifier="ROUTER"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str ports: List[PortObservation.ConfigSchema] + num_ports: int + acl: ACLObservation.ConfigSchema + ip_list: List[str] + port_list: List[int] + protocol_list: List[str] + num_rules: int + def __init__(self, + where: WhereType, + ports:List[PortObservation], + num_ports: int, + acl: ACLObservation, + )->None: + self.where: WhereType = where + self.ports: List[PortObservation] = ports + self.acl: ACLObservation = acl + self.num_ports:int = num_ports - def __init__(self, where: WhereType)->None: - pass + while len(self.ports) < num_ports: + self.ports.append(PortObservation(where=None)) + while len(self.ports) > num_ports: + self.ports.pop() + msg = f"Too many ports in router observation. Truncating." 
+ _LOGGER.warning(msg) + + self.default_observation = { + "PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)}, + "ACL": self.acl.default_observation + } def observe(self, state: Dict) -> Any: - pass + router_state = access_from_nested_dict(state, self.where) + if router_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["PORTS"] = {i+1:p.observe(state) for i,p in enumerate(self.ports)} + obs["ACL"] = self.acl.observe(state) + return obs @property def space(self) -> spaces.Space: - pass + return spaces.Dict({ + "PORTS": {i+1:p.space for i,p in self.ports}, + "ACL": self.acl.space + }) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation: + where = parent_where + ["nodes", config.hostname] + + if config.acl.num_rules is None: + config.acl.num_rules = config.num_rules + if config.acl.ip_list is None: + config.acl.ip_list = config.ip_list + if config.acl.port_list is None: + config.acl.port_list = config.port_list + if config.acl.protocol_list is None: + config.acl.protocol_list = config.protocol_list + + ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] + acl = ACLObservation.from_config(config=config.acl, parent_where=where) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) class FirewallObservation(AbstractObservation, identifier="FIREWALL"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str - ports: List[PortObservation.ConfigSchema] = [] + ip_list: List[str] + port_list: List[int] + protocol_list: List[str] + num_rules: int - def __init__(self, where: WhereType)->None: - pass + + def __init__(self, + where: WhereType, + ip_list: List[str], + port_list: List[int], + protocol_list: List[str], + num_rules: int, + )->None: + self.where: WhereType = where + + self.ports: 
List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ] + #TODO: check what the port nums are for firewall. + + self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + + + self.default_observation = { + "PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.default_observation, + "OUTBOUND": self.internal_outbound_acl.default_observation, + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.default_observation, + "OUTBOUND": self.dmz_outbound_acl.default_observation, + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.default_observation, + "OUTBOUND": self.external_outbound_acl.default_observation, + }, + } def observe(self, state: Dict) -> Any: - pass + obs = { + "PORTS": {i+1:p.observe(state) for i,p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": 
self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, + } + return obs @property def space(self) -> spaces.Space: - pass + space =spaces.Dict({ + "PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}), + "INTERNAL": spaces.Dict({ + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + }), + "DMZ": spaces.Dict({ + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + }), + "EXTERNAL": spaces.Dict({ + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + }), + }) + return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation: + where = parent_where+["nodes", config.hostname] + return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules) class NodesObservation(AbstractObservation, identifier="NODES"): class ConfigSchema(AbstractObservation.ConfigSchema): @@ -448,205 +654,104 @@ class NodesObservation(AbstractObservation, identifier="NODES"): hosts: List[HostObservation.ConfigSchema] = [] routers: List[RouterObservation.ConfigSchema] = [] firewalls: List[FirewallObservation.ConfigSchema] = [] - num_services: int = 1 + + num_services: int + num_applications: int + num_folders: int + num_files: int + num_nics: int + include_nmne: bool + include_num_access: bool + + ip_list: List[str] + port_list: List[int] + protocol_list: List[str] + num_rules: int - def __init__(self, where: WhereType)->None: - pass + def __init__(self, where: WhereType, 
hosts:List[HostObservation], routers:List[RouterObservation], firewalls:List[FirewallObservation])->None: + self.where :WhereType = where + + self.hosts: List[HostObservation] = hosts + self.routers: List[RouterObservation] = routers + self.firewalls: List[FirewallObservation] = firewalls + + self.default_observation = { + **{f"HOST{i}":host.default_observation for i,host in enumerate(self.hosts)}, + **{f"ROUTER{i}":router.default_observation for i,router in enumerate(self.routers)}, + **{f"FIREWALL{i}":firewall.default_observation for i,firewall in enumerate(self.firewalls)}, + } def observe(self, state: Dict) -> Any: - pass - - @property - def space(self) -> spaces.Space: - pass - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass - -############################ OLD - -class NodeObservation(AbstractObservation, identifier= "OLD"): - """Observation of a node in the network. Includes services, folders and NICs.""" - - def __init__( - self, - where: Optional[Tuple[str]] = None, - services: List[ServiceObservation] = [], - folders: List[FolderObservation] = [], - network_interfaces: List[NicObservation] = [], - logon_status: bool = False, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> None: - """ - Configurable observation for a node in the simulation. - - :param where: Where in the simulation state dictionary for find relevant information for this observation. - A typical location for a node looks like this: - ['network','nodes',]. 
If empty list, a default null observation will be output, defaults to [] - :type where: List[str], optional - :param services: Mapping between position in observation space and service name, defaults to {} - :type services: Dict[int,str], optional - :param max_services: Max number of services that can be presented in observation space for this node - , defaults to 2 - :type max_services: int, optional - :param folders: Mapping between position in observation space and folder name, defaults to {} - :type folders: Dict[int,str], optional - :param max_folders: Max number of folders in this node's obs space, defaults to 2 - :type max_folders: int, optional - :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} - :type network_interfaces: Dict[int,str], optional - :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 - :type max_nics: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.services: List[ServiceObservation] = services - while len(self.services) < num_services_per_node: - # add empty service observation without `where` parameter so it always returns default (blank) observation - self.services.append(ServiceObservation()) - while len(self.services) > num_services_per_node: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - # truncate service list - - self.folders: List[FolderObservation] = folders - # add empty folder observation without `where` parameter that will always return default (blank) observations - while len(self.folders) < num_folders_per_node: - self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) - while len(self.folders) > num_folders_per_node: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation for node. 
Truncating service {truncated_folder.where[-1]}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NicObservation] = network_interfaces - while len(self.network_interfaces) < num_nics_per_node: - self.network_interfaces.append(NicObservation()) - while len(self.network_interfaces) > num_nics_per_node: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" - _LOGGER.warning(msg) - - self.logon_status: bool = logon_status - - self.default_observation: Dict = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, + obs = { + **{f"HOST{i}":host.observe(state) for i,host in enumerate(self.hosts)}, + **{f"ROUTER{i}":router.observe(state) for i,router in enumerate(self.routers)}, + **{f"FIREWALL{i}":firewall.observe(state) for i,firewall in enumerate(self.firewalls)}, } - if self.logon_status: - self.default_observation["logon_status"] = 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NICS"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - - if self.logon_status: - obs["logon_status"] = 0 - return obs @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - space_shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NICS": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - } - if self.logon_status: - space_shape["logon_status"] = spaces.Discrete(3) - - return spaces.Dict(space_shape) + space = spaces.Dict({ + **{f"HOST{i}":host.space for i,host in enumerate(self.hosts)}, + **{f"ROUTER{i}":router.space for i,router in enumerate(self.routers)}, + **{f"FIREWALL{i}":firewall.space for i,firewall in enumerate(self.firewalls)}, + }) + return space @classmethod - def from_config( - cls, - config: Dict, - game: "PrimaiteGame", - parent_where: Optional[List[str]] = None, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> "NodeObservation": - """Create node observation from a config. Also creates child service, folder and NIC observations. 
- - :param config: Dictionary containing the configuration for this node observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this node's parent - network. A typical location for it would be: ['network',] - :type parent_where: Optional[List[str]] - :param num_services_per_node: How many spaces for services are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_services_per_node: int, optional - :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_folders_per_node: int, optional - :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed node observation - :rtype: NodeObservation - """ - node_hostname = config["node_hostname"] + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: if parent_where is None: - where = ["network", "nodes", node_hostname] + where = ["network", "nodes"] else: - where = parent_where + ["nodes", node_hostname] + where = parent_where + ["nodes"] - svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] - folder_configs = config.get("folders", {}) - folders = [ - FolderObservation.from_config( - config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder - ) - for c in folder_configs - ] - # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. 
- nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] - network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] - logon_status = config.get("logon_status", False) - return cls( - where=where, - services=services, - folders=folders, - network_interfaces=network_interfaces, - logon_status=logon_status, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) + for host_config in config.hosts: + if host_config.num_services is None: + host_config.num_services = config.num_services + if host_config.num_applications is None: + host_config.num_application = config.num_applications + if host_config.num_folders is None: + host_config.num_folder = config.num_folders + if host_config.num_files is None: + host_config.num_file = config.num_files + if host_config.num_nics is None: + host_config.num_nic = config.num_nics + if host_config.include_nmne is None: + host_config.include_nmne = config.include_nmne + if host_config.include_num_access is None: + host_config.include_num_access = config.include_num_access + + for router_config in config.routers: + if router_config.num_ports is None: + router_config.num_ports = config.num_ports + if router_config.ip_list is None: + router_config.ip_list = config.ip_list + + if router_config.port_list is None: + router_config.port_list = config.port_list + + if router_config.protocol_list is None: + router_config.protocol_list = config.protocol_list + + if router_config.num_rules is None: + router_config.num_rules = config.num_rules + + for firewall_config in config.firewalls: + if firewall_config.ip_list is None: + firewall_config.ip_list = config.ip_list + + if firewall_config.port_list is None: + firewall_config.port_list = config.port_list + + if firewall_config.protocol_list is None: + firewall_config.protocol_list = config.protocol_list + + if 
firewall_config.num_rules is None: + firewall_config.num_rules = config.num_rules + + hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] + routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] + firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls] + + cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls) From d8a66104f50c2c9f41b4522d99bbf4988513ef80 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 29 Mar 2024 11:55:22 +0000 Subject: [PATCH 05/16] Fixed observations --- .../agent/observations/node_observations.py | 173 ++++++++++-------- .../network/hardware/nodes/network/router.py | 2 + 2 files changed, 102 insertions(+), 73 deletions(-) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 5d46b743..b51ea1f2 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): class FileObservation(AbstractObservation, identifier="FILE"): class ConfigSchema(AbstractObservation.ConfigSchema): file_name: str - include_num_access : bool = False + include_num_access: Optional[bool] = None def __init__(self, where: WhereType, include_num_access: bool)->None: self.where: WhereType = where @@ -118,8 +118,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): class ConfigSchema(AbstractObservation.ConfigSchema): folder_name: str files: List[FileObservation.ConfigSchema] = [] - num_files : int = 0 - include_num_access : bool = False + num_files : Optional[int] = None + include_num_access : Optional[bool] = None def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None: self.where: WhereType = where @@ -179,7 
+179,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): class ConfigSchema(AbstractObservation.ConfigSchema): nic_num: int - include_nmne: bool = False + include_nmne: Optional[bool] = None def __init__(self, where: WhereType, include_nmne: bool)->None: @@ -233,13 +233,13 @@ class HostObservation(AbstractObservation, identifier="HOST"): applications: List[ApplicationObservation.ConfigSchema] = [] folders: List[FolderObservation.ConfigSchema] = [] network_interfaces: List[NICObservation.ConfigSchema] = [] - num_services: int - num_applications: int - num_folders: int - num_files: int - num_nics: int - include_nmne: bool - include_num_access: bool + num_services: Optional[int] = None + num_applications: Optional[int] = None + num_folders: Optional[int] = None + num_files: Optional[int] = None + num_nics: Optional[int] = None + include_nmne: Optional[bool] = None + include_num_access: Optional[bool] = None def __init__(self, where: WhereType, @@ -296,6 +296,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.default_observation: ObsType = { "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)}, "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, "operating_status": 0, @@ -311,6 +312,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): obs = {} obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} obs["operating_status"] = node_state["operating_state"] obs["NICS"] = { @@ -324,6 +326,7 @@ class 
HostObservation(AbstractObservation, identifier="HOST"): def space(self) -> spaces.Space: shape = { "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}), "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), "operating_status": spaces.Discrete(5), "NICS": spaces.Dict( @@ -393,15 +396,17 @@ class PortObservation(AbstractObservation, identifier="PORT"): class ACLObservation(AbstractObservation, identifier="ACL"): class ConfigSchema(AbstractObservation.ConfigSchema): - ip_list: List[IPv4Address] - port_list: List[int] - protocol_list: List[str] - num_rules: int + ip_list: Optional[List[IPv4Address]] = None + wildcard_list: Optional[List[str]] = None + port_list: Optional[List[int]] = None + protocol_list: Optional[List[str]] = None + num_rules: Optional[int] = None - def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None: + def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], wildcard_list: List[str], port_list: List[int],protocol_list: List[str])->None: self.where = where self.num_rules: int = num_rules self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)} + self.wildcard_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(wildcard_list)} self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)} self.default_observation: Dict = { @@ -409,10 +414,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"): + 1: { "position": i, "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, + "source_ip_id": 0, + "source_wildcard_id": 0, + "source_port_id": 0, + "dest_ip_id": 0, + "dest_wildcard_id": 0, + "dest_port_id": 0, "protocol": 0, } for i 
in range(self.num_rules) @@ -431,30 +438,38 @@ class ACLObservation(AbstractObservation, identifier="ACL"): obs[i] = { "position": i - 1, "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, + "source_ip_id": 0, + "source_wildcard_id": 0, + "source_port_id": 0, + "dest_ip_id": 0, + "dest_wildcard_id": 0, + "dest_port_id": 0, "protocol": 0, } else: src_ip = rule_state["src_ip_address"] - src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + src_node_id = self.ip_to_id.get(src_ip, 1) dst_ip = rule_state["dst_ip_address"] - dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] - src_port = rule_state["src_port"] - src_port_id = 1 if src_port is None else self.port_to_id[src_port] - dst_port = rule_state["dst_port"] - dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + dst_node_ip = self.ip_to_id.get(dst_ip, 1) + src_wildcard = rule_state["source_wildcard_id"] + src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1) + dst_wildcard = rule_state["dest_wildcard_id"] + dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1) + src_port = rule_state["source_port_id"] + src_port_id = self.port_to_id.get(src_port, 1) + dst_port = rule_state["dest_port_id"] + dst_port_id = self.port_to_id.get(dst_port, 1) protocol = rule_state["protocol"] - protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { "position": i - 1, "permission": rule_state["action"], - "source_node_id": src_node_id, - "source_port": src_port_id, - "dest_node_id": dst_node_ip, - "dest_port": dst_port_id, + "source_ip_id": src_node_id, + "source_wildcard_id": src_wildcard_id, + "source_port_id": src_port_id, + "dest_ip_id": dst_node_ip, + "dest_wildcard_id": dst_wildcard_id, + "dest_port_id": dst_port_id, "protocol": protocol_id, } i += 1 @@ -462,7 +477,6 @@ class ACLObservation(AbstractObservation, identifier="ACL"): 
@property def space(self) -> spaces.Space: - raise NotImplementedError("TODO: need to add wildcard id.") return spaces.Dict( { i @@ -471,10 +485,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "position": spaces.Discrete(self.num_rules), "permission": spaces.Discrete(3), # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), + "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), + "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), + "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), } ) @@ -489,20 +505,22 @@ class ACLObservation(AbstractObservation, identifier="ACL"): where = parent_where+["acl", "acl"], num_rules = config.num_rules, ip_list = config.ip_list, - ports = config.port_list, - protocols = config.protocol_list + wildcard_list = config.wildcard_list, + port_list = config.port_list, + protocol_list = config.protocol_list ) class RouterObservation(AbstractObservation, identifier="ROUTER"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str - ports: List[PortObservation.ConfigSchema] - num_ports: int - acl: ACLObservation.ConfigSchema - ip_list: List[str] - port_list: List[int] - protocol_list: List[str] - num_rules: int + ports: Optional[List[PortObservation.ConfigSchema]] = None + num_ports: Optional[int] = None + acl: Optional[ACLObservation.ConfigSchema] = None + ip_list: Optional[List[str]] = None + wildcard_list: Optional[List[str]] = None + port_list: 
Optional[List[int]] = None + protocol_list: Optional[List[str]] = None + num_rules: Optional[int] = None def __init__(self, where: WhereType, @@ -540,7 +558,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): @property def space(self) -> spaces.Space: return spaces.Dict({ - "PORTS": {i+1:p.space for i,p in self.ports}, + "PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}), "ACL": self.acl.space }) @@ -548,15 +566,22 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation: where = parent_where + ["nodes", config.hostname] + if config.acl is None: + config.acl = ACLObservation.ConfigSchema() if config.acl.num_rules is None: config.acl.num_rules = config.num_rules if config.acl.ip_list is None: config.acl.ip_list = config.ip_list + if config.acl.wildcard_list is None: + config.acl.wildcard_list = config.wildcard_list if config.acl.port_list is None: config.acl.port_list = config.port_list if config.acl.protocol_list is None: config.acl.protocol_list = config.protocol_list + if config.ports is None: + config.ports = [PortObservation.ConfigSchema(port_id=i+1) for i in range(config.num_ports)] + ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) @@ -564,30 +589,32 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): class FirewallObservation(AbstractObservation, identifier="FIREWALL"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str - ip_list: List[str] - port_list: List[int] - protocol_list: List[str] - num_rules: int + ip_list: Optional[List[str]] = None + wildcard_list: Optional[List[str]] = None + port_list: Optional[List[int]] = None + protocol_list: Optional[List[str]] = None + num_rules: Optional[int] = 
None def __init__(self, where: WhereType, ip_list: List[str], + wildcard_list: List[str], port_list: List[int], protocol_list: List[str], num_rules: int, )->None: self.where: WhereType = where - self.ports: List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ] + self.ports: List[PortObservation] = [PortObservation(where=self.where+["port", port_num]) for port_num in (1,2,3) ] #TODO: check what the port nums are for firewall. - self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + 
self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) self.default_observation = { @@ -646,7 +673,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): @classmethod def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation: where = parent_where+["nodes", config.hostname] - return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules) + return cls(where=where, ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules) class NodesObservation(AbstractObservation, identifier="NODES"): class ConfigSchema(AbstractObservation.ConfigSchema): @@ -663,7 +690,9 @@ class NodesObservation(AbstractObservation, identifier="NODES"): include_nmne: bool include_num_access: bool + num_ports: int ip_list: List[str] + wildcard_list: List[str] port_list: List[int] protocol_list: List[str] num_rules: int @@ -710,13 +739,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"): if host_config.num_services is None: host_config.num_services = config.num_services if 
host_config.num_applications is None: - host_config.num_application = config.num_applications + host_config.num_applications = config.num_applications if host_config.num_folders is None: - host_config.num_folder = config.num_folders + host_config.num_folders = config.num_folders if host_config.num_files is None: - host_config.num_file = config.num_files + host_config.num_files = config.num_files if host_config.num_nics is None: - host_config.num_nic = config.num_nics + host_config.num_nics = config.num_nics if host_config.include_nmne is None: host_config.include_nmne = config.include_nmne if host_config.include_num_access is None: @@ -727,26 +756,24 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.num_ports = config.num_ports if router_config.ip_list is None: router_config.ip_list = config.ip_list - + if router_config.wildcard_list is None: + router_config.wildcard_list = config.wildcard_list if router_config.port_list is None: router_config.port_list = config.port_list - if router_config.protocol_list is None: router_config.protocol_list = config.protocol_list - if router_config.num_rules is None: router_config.num_rules = config.num_rules for firewall_config in config.firewalls: if firewall_config.ip_list is None: firewall_config.ip_list = config.ip_list - + if firewall_config.wildcard_list is None: + firewall_config.wildcard_list = config.wildcard_list if firewall_config.port_list is None: firewall_config.port_list = config.port_list - if firewall_config.protocol_list is None: firewall_config.protocol_list = config.protocol_list - if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules @@ -754,4 +781,4 @@ class NodesObservation(AbstractObservation, identifier="NODES"): routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls] - cls(where=where, hosts=hosts, 
routers=routers, firewalls=firewalls) + return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index d2b47c1a..69ab6a82 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -147,8 +147,10 @@ class ACLRule(SimComponent): state["action"] = self.action.value state["protocol"] = self.protocol.name if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None + state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None state["src_port"] = self.src_port.name if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None + state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None state["dst_port"] = self.dst_port.name if self.dst_port else None state["match_count"] = self.match_count return state From 1751714d3d8f1f5e78a6b97f712765dbd23cc6fd Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 29 Mar 2024 12:21:52 +0000 Subject: [PATCH 06/16] Tidy up node observation file --- .../game/agent/observations/node_observations.py | 15 +++++---------- 1 file changed, 5 insertions(+), 10 deletions(-) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index b51ea1f2..ed930265 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,17 +1,12 @@ -# TODO: make sure when config options are being passed down from higher-level observations to lower-level, but the lower-level also defines that option, don't overwrite. 
from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, Iterable, List, Optional from gymnasium import spaces from gymnasium.core import ObsType -from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation -# from primaite.game.agent.observations.file_system_observations import FolderObservation -# from primaite.game.agent.observations.nic_observations import NicObservation -# from primaite.game.agent.observations.software_observation import ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) @@ -420,7 +415,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": 0, "dest_wildcard_id": 0, "dest_port_id": 0, - "protocol": 0, + "protocol_id": 0, } for i in range(self.num_rules) } @@ -444,7 +439,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": 0, "dest_wildcard_id": 0, "dest_port_id": 0, - "protocol": 0, + "protocol_id": 0, } else: src_ip = rule_state["src_ip_address"] @@ -470,7 +465,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": dst_node_ip, "dest_wildcard_id": dst_wildcard_id, "dest_port_id": dst_port_id, - "protocol": protocol_id, + "protocol_id": protocol_id, } i += 1 return obs @@ -491,7 +486,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), + "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), } ) for i in range(self.num_rules) From 9123aff592e952df7f9df5d9257dbbb5c9ef973a Mon Sep 17 00:00:00 2001 From: Marek 
Wolan Date: Fri, 29 Mar 2024 13:15:31 +0000 Subject: [PATCH 07/16] #2417 Add hella docstrings --- .../agent/observations/node_observations.py | 997 ++++++++++++++---- 1 file changed, 792 insertions(+), 205 deletions(-) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index ed930265..c702f8e2 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,4 +1,5 @@ from __future__ import annotations + from ipaddress import IPv4Address from typing import Any, Dict, Iterable, List, Optional @@ -15,14 +16,34 @@ WhereType = Iterable[str | int] | None class ServiceObservation(AbstractObservation, identifier="SERVICE"): - class ConfigSchema(AbstractObservation.ConfigSchema): - service_name: str + """Service observation, shows status of a service in the simulation environment.""" - def __init__(self, where: WhereType)->None: + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ServiceObservation.""" + + service_name: str + """Name of the service, used for querying simulation state dictionary""" + + def __init__(self, where: WhereType) -> None: + """ + Initialize a service observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this service. + A typical location for a service might be ['network', 'nodes', , 'services', ]. + :type where: WhereType + """ self.where = where self.default_observation = {"operating_status": 0, "health_status": 0} - def observe(self, state: Dict) -> Any: + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the operating status and health status of the service. 
+ :rtype: Any + """ service_state = access_from_nested_dict(state, self.where) if service_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -33,24 +54,60 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for service status. + :rtype: spaces.Space + """ return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - return cls(where=parent_where+["services", config.service_name]) + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: + """ + Create a service observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the service observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this service's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed service observation instance. 
+ :rtype: ServiceObservation + """ + return cls(where=parent_where + ["services", config.service_name]) class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): - class ConfigSchema(AbstractObservation.ConfigSchema): - application_name: str + """Application observation, shows the status of an application within the simulation environment.""" - def __init__(self, where: WhereType)->None: + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ApplicationObservation.""" + + application_name: str + """Name of the application, used for querying simulation state dictionary""" + + def __init__(self, where: WhereType) -> None: + """ + Initialise an application observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this application. + A typical location for an application might be + ['network', 'nodes', , 'applications', ]. + :type where: WhereType + """ self.where = where self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} def observe(self, state: Dict) -> Any: - # raise NotImplementedError("TODO NUM EXECUTIONS NEEDS TO BE CONVERTED TO A CATEGORICAL") + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Obs containing the operating status, health status, and number of executions of the application. 
+ :rtype: Any + """ application_state = access_from_nested_dict(state, self.where) if application_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -62,32 +119,74 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({ - "operating_status": spaces.Discrete(7), - "health_status": spaces.Discrete(5), - "num_executions": spaces.Discrete(4) - }) + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for application status. + :rtype: spaces.Space + """ + return spaces.Dict( + { + "operating_status": spaces.Discrete(7), + "health_status": spaces.Discrete(5), + "num_executions": spaces.Discrete(4), + } + ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ApplicationObservation: - return cls(where=parent_where+["applications", config.application_name]) + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation: + """ + Create an application observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the application observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this application's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed application observation instance. 
+ :rtype: ApplicationObservation + """ + return cls(where=parent_where + ["applications", config.application_name]) class FileObservation(AbstractObservation, identifier="FILE"): - class ConfigSchema(AbstractObservation.ConfigSchema): - file_name: str - include_num_access: Optional[bool] = None + """File observation, provides status information about a file within the simulation environment.""" - def __init__(self, where: WhereType, include_num_access: bool)->None: + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FileObservation.""" + + file_name: str + """Name of the file, used for querying simulation state dictionary.""" + include_num_access: Optional[bool] = None + """Whether to include the number of accesses to the file in the observation.""" + + def __init__(self, where: WhereType, include_num_access: bool) -> None: + """ + Initialize a file observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this file. + A typical location for a file might be + ['network', 'nodes', , 'file_system', 'folder', , 'files', ]. + :type where: WhereType + :param include_num_access: Whether to include the number of accesses to the file in the observation. + :type include_num_access: bool + """ self.where: WhereType = where - self.include_num_access :bool = include_num_access + self.include_num_access: bool = include_num_access self.default_observation: ObsType = {"health_status": 0} if self.include_num_access: self.default_observation["num_access"] = 0 def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the health status of the file and optionally the number of accesses. 
+ :rtype: Any + """ file_state = access_from_nested_dict(state, self.where) if file_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -99,29 +198,69 @@ class FileObservation(AbstractObservation, identifier="FILE"): @property def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for file status. + :rtype: spaces.Space + """ space = {"health_status": spaces.Discrete(6)} if self.include_num_access: space["num_access"] = spaces.Discrete(4) return spaces.Dict(space) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation: - return cls(where=parent_where+["files", config.file_name], include_num_access=config.include_num_access) + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation: + """ + Create a file observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the file observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this file's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed file observation instance. 
+ :rtype: FileObservation + """ + return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access) class FolderObservation(AbstractObservation, identifier="FOLDER"): - class ConfigSchema(AbstractObservation.ConfigSchema): - folder_name: str - files: List[FileObservation.ConfigSchema] = [] - num_files : Optional[int] = None - include_num_access : Optional[bool] = None + """Folder observation, provides status information about a folder within the simulation environment.""" - def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None: + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FolderObservation.""" + + folder_name: str + """Name of the folder, used for querying simulation state dictionary.""" + files: List[FileObservation.ConfigSchema] = [] + """List of file configurations within the folder.""" + num_files: Optional[int] = None + """Number of spaces for file observations in this folder.""" + include_num_access: Optional[bool] = None + """Whether files in this folder should include the number of accesses in their observation.""" + + def __init__( + self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool + ) -> None: + """ + Initialize a folder observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this folder. + A typical location for a folder might be ['network', 'nodes', , 'folders', ]. + :type where: WhereType + :param files: List of file observation instances within the folder. + :type files: Iterable[FileObservation] + :param num_files: Number of files expected in the folder. + :type num_files: int + :param include_num_access: Whether to include the number of accesses to files in the observation. 
+ :type include_num_access: bool + """ self.where: WhereType = where self.files: List[FileObservation] = files while len(self.files) < num_files: - self.files.append(FileObservation(where=None,include_num_access=include_num_access)) + self.files.append(FileObservation(where=None, include_num_access=include_num_access)) while len(self.files) > num_files: truncated_file = self.files.pop() msg = f"Too many files in folder observation. Truncating file {truncated_file}" @@ -133,6 +272,14 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): } def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the health status of the folder and status of files within the folder. + :rtype: Any + """ folder_state = access_from_nested_dict(state, self.where) if folder_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -148,9 +295,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. + """ + Gymnasium space object describing the observation space shape. - :return: Gymnasium space + :return: Gymnasium space representing the observation space for folder status. :rtype: spaces.Space """ return spaces.Dict( @@ -159,34 +307,68 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), } ) + @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FolderObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation: + """ + Create a folder observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the folder observation. 
+ :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this folder's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed folder observation instance. + :rtype: FolderObservation + """ where = parent_where + ["folders", config.folder_name] - #pass down shared/common config items + # pass down shared/common config items for file_config in config.files: file_config.include_num_access = config.include_num_access - files = [FileObservation.from_config(config=f, parent_where = where) for f in config.files] + files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): + """Status information about a network interface within the simulation environment.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for NICObservation.""" + nic_num: int + """Number of the network interface.""" include_nmne: Optional[bool] = None + """Whether to include number of malicious network events (NMNE) in the observation.""" + def __init__(self, where: WhereType, include_nmne: bool) -> None: + """ + Initialize a network interface observation instance. - def __init__(self, where: WhereType, include_nmne: bool)->None: + :param where: Where in the simulation state dictionary to find the relevant information for this interface. + A typical location for a network interface might be + ['network', 'nodes', , 'NICs', ]. + :type where: WhereType + :param include_nmne: Flag to determine whether to include NMNE information in the observation. 
+ :type include_nmne: bool + """ self.where = where - self.include_nmne : bool = include_nmne + self.include_nmne: bool = include_nmne self.default_observation: ObsType = {"nic_status": 0} if self.include_nmne: - self.default_observation.update({"NMNE":{"inbound":0, "outbound":0}}) + self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}}) def observe(self, state: Dict) -> Any: - # raise NotImplementedError("TODO: CATEGORISATION") + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status of the network interface and optionally NMNE information. + :rtype: Any + """ nic_state = access_from_nested_dict(state, self.where) if nic_state is NOT_PRESENT_IN_STATE: @@ -206,9 +388,14 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): self.nmne_outbound_last_step = outbound_count return obs - @property def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for network interface status and NMNE information. + :rtype: spaces.Space + """ space = spaces.Dict({"nic_status": spaces.Discrete(3)}) if self.include_nmne: @@ -217,43 +404,99 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> NICObservation: - return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne) + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: + """ + Create a network interface observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the network interface observation. 
+ :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this NIC's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed network interface observation instance. + :rtype: NICObservation + """ + return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne) class HostObservation(AbstractObservation, identifier="HOST"): + """Host observation, provides status information about a host within the simulation environment.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for HostObservation.""" + hostname: str + """Hostname of the host, used for querying simulation state dictionary.""" services: List[ServiceObservation.ConfigSchema] = [] + """List of services to observe on the host.""" applications: List[ApplicationObservation.ConfigSchema] = [] + """List of applications to observe on the host.""" folders: List[FolderObservation.ConfigSchema] = [] + """List of folders to observe on the host.""" network_interfaces: List[NICObservation.ConfigSchema] = [] + """List of network interfaces to observe on the host.""" num_services: Optional[int] = None + """Number of spaces for service observations on this host.""" num_applications: Optional[int] = None + """Number of spaces for application observations on this host.""" num_folders: Optional[int] = None + """Number of spaces for folder observations on this host.""" num_files: Optional[int] = None + """Number of spaces for file observations on this host.""" num_nics: Optional[int] = None + """Number of spaces for network interface observations on this host.""" include_nmne: Optional[bool] = None + """Whether network interface observations should include number of malicious network events.""" include_num_access: Optional[bool] = None + """Whether to include the number of accesses to files observations on this 
host.""" - def __init__(self, - where: WhereType, - services:List[ServiceObservation], - applications:List[ApplicationObservation], - folders:List[FolderObservation], - network_interfaces:List[NICObservation], - num_services: int, - num_applications: int, - num_folders: int, - num_files: int, - num_nics: int, - include_nmne: bool, - include_num_access: bool - )->None: + def __init__( + self, + where: WhereType, + services: List[ServiceObservation], + applications: List[ApplicationObservation], + folders: List[FolderObservation], + network_interfaces: List[NICObservation], + num_services: int, + num_applications: int, + num_folders: int, + num_files: int, + num_nics: int, + include_nmne: bool, + include_num_access: bool, + ) -> None: + """ + Initialize a host observation instance. - self.where : WhereType = where + :param where: Where in the simulation state dictionary to find the relevant information for this host. + A typical location for a host might be ['network', 'nodes', ]. + :type where: WhereType + :param services: List of service observations on the host. + :type services: List[ServiceObservation] + :param applications: List of application observations on the host. + :type applications: List[ApplicationObservation] + :param folders: List of folder observations on the host. + :type folders: List[FolderObservation] + :param network_interfaces: List of network interface observations on the host. + :type network_interfaces: List[NICObservation] + :param num_services: Number of services to observe. + :type num_services: int + :param num_applications: Number of applications to observe. + :type num_applications: int + :param num_folders: Number of folders to observe. + :type num_folders: int + :param num_files: Number of files. + :type num_files: int + :param num_nics: Number of network interfaces. + :type num_nics: int + :param include_nmne: Flag to include network metrics and errors. 
+ :type include_nmne: bool + :param include_num_access: Flag to include the number of accesses to files. + :type include_num_access: bool + """ + self.where: WhereType = where - # ensure service list has length equal to num_services by truncating or padding + # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services while len(self.services) < num_services: self.services.append(ServiceObservation(where=None)) @@ -262,31 +505,30 @@ class HostObservation(AbstractObservation, identifier="HOST"): msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" _LOGGER.warning(msg) - # ensure application list has length equal to num_applications by truncating or padding self.applications: List[ApplicationObservation] = applications while len(self.applications) < num_applications: self.applications.append(ApplicationObservation(where=None)) while len(self.applications) > num_applications: truncated_application = self.applications.pop() - msg = f"Too many applications in Node observation space for node. Truncating application {truncated_application.where}" + msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" _LOGGER.warning(msg) - # ensure folder list has length equal to num_folders by truncating or padding self.folders: List[FolderObservation] = folders while len(self.folders) < num_folders: - self.folders.append(FolderObservation(where = None, files= [], num_files=num_files, include_num_access=include_num_access)) + self.folders.append( + FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access) + ) while len(self.folders) > num_folders: truncated_folder = self.folders.pop() msg = f"Too many folders in Node observation space for node. 
Truncating folder {truncated_folder.where}" _LOGGER.warning(msg) - # ensure network_interface list has length equal to num_network_interfaces by truncating or padding self.network_interfaces: List[NICObservation] = network_interfaces while len(self.network_interfaces) < num_nics: - self.network_interfaces.append(NICObservation(where = None, include_nmne=include_nmne)) + self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne)) while len(self.network_interfaces) > num_nics: truncated_nic = self.network_interfaces.pop() - msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_folder.where}" + msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" _LOGGER.warning(msg) self.default_observation: ObsType = { @@ -299,8 +541,15 @@ class HostObservation(AbstractObservation, identifier="HOST"): "num_file_deletions": 0, } - def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status information about the host. + :rtype: Any + """ node_state = access_from_nested_dict(state, self.where) if node_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -319,6 +568,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): @property def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for host status. 
+ :rtype: spaces.Space + """ shape = { "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}), @@ -327,83 +582,165 @@ class HostObservation(AbstractObservation, identifier="HOST"): "NICS": spaces.Dict( {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} ), - "num_file_creations" : spaces.Discrete(4), - "num_file_deletions" : spaces.Discrete(4), + "num_file_creations": spaces.Discrete(4), + "num_file_deletions": spaces.Discrete(4), } return spaces.Dict(shape) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> HostObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation: + """ + Create a host observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the host observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this host. + A typical location might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed host observation instance. 
+ :rtype: HostObservation + """ if parent_where is None: where = ["network", "nodes", config.hostname] else: where = parent_where + ["nodes", config.hostname] - #pass down shared/common config items + # Pass down shared/common config items for folder_config in config.folders: folder_config.include_num_access = config.include_num_access folder_config.num_files = config.num_files for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne - services = [ServiceObservation.from_config(config=c,parent_where=where) for c in config.services] + services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] return cls( - where = where, - services = services, - applications = applications, - folders = folders, - network_interfaces = nics, - num_services = config.num_services, - num_applications = config.num_applications, - num_folders = config.num_folders, - num_files = config.num_files, - num_nics = config.num_nics, - include_nmne = config.include_nmne, - include_num_access = config.include_num_access, + where=where, + services=services, + applications=applications, + folders=folders, + network_interfaces=nics, + num_services=config.num_services, + num_applications=config.num_applications, + num_folders=config.num_folders, + num_files=config.num_files, + num_nics=config.num_nics, + include_nmne=config.include_nmne, + include_num_access=config.include_num_access, ) class PortObservation(AbstractObservation, identifier="PORT"): - class ConfigSchema(AbstractObservation.ConfigSchema): - port_id : int + """Port observation, provides status information about a network port within the simulation environment.""" - 
def __init__(self, where: WhereType)->None: + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for PortObservation.""" + + port_id: int + """Identifier of the port, used for querying simulation state dictionary.""" + + def __init__(self, where: WhereType) -> None: + """ + Initialize a port observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this port. + A typical location for a port might be ['network', 'nodes', , 'NICs', ]. + :type where: WhereType + """ self.where = where - self.default_observation: ObsType = {"operating_status" : 0} + self.default_observation: ObsType = {"operating_status": 0} def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the operating status of the port. + :rtype: Any + """ port_state = access_from_nested_dict(state, self.where) if port_state is NOT_PRESENT_IN_STATE: return self.default_observation - return {"operating_status": 1 if port_state["enabled"] else 2 } + return {"operating_status": 1 if port_state["enabled"] else 2} @property def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for port status. + :rtype: spaces.Space + """ return spaces.Dict({"operating_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> PortObservation: - return cls(where = parent_where + ["NICs", config.port_id]) + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation: + """ + Create a port observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the port observation. 
+ :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this port's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed port observation instance. + :rtype: PortObservation + """ + return cls(where=parent_where + ["NICs", config.port_id]) + class ACLObservation(AbstractObservation, identifier="ACL"): - class ConfigSchema(AbstractObservation.ConfigSchema): - ip_list: Optional[List[IPv4Address]] = None - wildcard_list: Optional[List[str]] = None - port_list: Optional[List[int]] = None - protocol_list: Optional[List[str]] = None - num_rules: Optional[int] = None + """ACL observation, provides information about access control lists within the simulation environment.""" - def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], wildcard_list: List[str], port_list: List[int],protocol_list: List[str])->None: + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ACLObservation.""" + + ip_list: Optional[List[IPv4Address]] = None + """List of IP addresses.""" + wildcard_list: Optional[List[str]] = None + """List of wildcard strings.""" + port_list: Optional[List[int]] = None + """List of port numbers.""" + protocol_list: Optional[List[str]] = None + """List of protocol names.""" + num_rules: Optional[int] = None + """Number of ACL rules.""" + + def __init__( + self, + where: WhereType, + num_rules: int, + ip_list: List[IPv4Address], + wildcard_list: List[str], + port_list: List[int], + protocol_list: List[str], + ) -> None: + """ + Initialize an ACL observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this ACL. + :type where: WhereType + :param num_rules: Number of ACL rules. + :type num_rules: int + :param ip_list: List of IP addresses. 
+ :type ip_list: List[IPv4Address] + :param wildcard_list: List of wildcard strings. + :type wildcard_list: List[str] + :param port_list: List of port numbers. + :type port_list: List[int] + :param protocol_list: List of protocol names. + :type protocol_list: List[str] + """ self.where = where self.num_rules: int = num_rules - self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)} - self.wildcard_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(wildcard_list)} - self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)} + self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)} + self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)} + self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)} + self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i + 1: { @@ -421,6 +758,14 @@ class ACLObservation(AbstractObservation, identifier="ACL"): } def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing ACL rules. + :rtype: Any + """ acl_state: Dict = access_from_nested_dict(state, self.where) if acl_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -472,6 +817,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"): @property def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for ACL rules. 
+ :rtype: spaces.Space + """ return spaces.Dict( { i @@ -481,10 +832,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "permission": spaces.Discrete(3), # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), - "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), + "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), - "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), + "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), } @@ -493,72 +844,134 @@ class ACLObservation(AbstractObservation, identifier="ACL"): } ) - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ACLObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: + """ + Create an ACL observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the ACL observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this ACL's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed ACL observation instance. 
+ :rtype: ACLObservation + """ return cls( - where = parent_where+["acl", "acl"], - num_rules = config.num_rules, - ip_list = config.ip_list, - wildcard_list = config.wildcard_list, - port_list = config.port_list, - protocol_list = config.protocol_list - ) + where=parent_where + ["acl", "acl"], + num_rules=config.num_rules, + ip_list=config.ip_list, + wildcard_list=config.wildcard_list, + port_list=config.port_list, + protocol_list=config.protocol_list, + ) + class RouterObservation(AbstractObservation, identifier="ROUTER"): - class ConfigSchema(AbstractObservation.ConfigSchema): - hostname: str - ports: Optional[List[PortObservation.ConfigSchema]] = None - num_ports: Optional[int] = None - acl: Optional[ACLObservation.ConfigSchema] = None - ip_list: Optional[List[str]] = None - wildcard_list: Optional[List[str]] = None - port_list: Optional[List[int]] = None - protocol_list: Optional[List[str]] = None - num_rules: Optional[int] = None + """Router observation, provides status information about a router within the simulation environment.""" - def __init__(self, - where: WhereType, - ports:List[PortObservation], - num_ports: int, - acl: ACLObservation, - )->None: + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for RouterObservation.""" + + hostname: str + """Hostname of the router, used for querying simulation state dictionary.""" + ports: Optional[List[PortObservation.ConfigSchema]] = None + """Configuration of port observations for this router.""" + num_ports: Optional[int] = None + """Number of port observations configured for this router.""" + acl: Optional[ACLObservation.ConfigSchema] = None + """Configuration of ACL observation on this router.""" + ip_list: Optional[List[str]] = None + """List of IP addresses for encoding ACLs.""" + wildcard_list: Optional[List[str]] = None + """List of IP wildcards for encoding ACLs.""" + port_list: Optional[List[int]] = None + """List of ports for encoding ACLs.""" + protocol_list: 
Optional[List[str]] = None + """List of protocols for encoding ACLs.""" + num_rules: Optional[int] = None + """Number of ACL rules to show.""" + + def __init__( + self, + where: WhereType, + ports: List[PortObservation], + num_ports: int, + acl: ACLObservation, + ) -> None: + """ + Initialize a router observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this router. + A typical location for a router might be ['network', 'nodes', <node_hostname>]. + :type where: WhereType + :param ports: List of port observations representing the ports of the router. + :type ports: List[PortObservation] + :param num_ports: Number of ports for the router. + :type num_ports: int + :param acl: ACL observation representing the access control list of the router. + :type acl: ACLObservation + """ self.where: WhereType = where self.ports: List[PortObservation] = ports self.acl: ACLObservation = acl - self.num_ports:int = num_ports + self.num_ports: int = num_ports while len(self.ports) < num_ports: self.ports.append(PortObservation(where=None)) while len(self.ports) > num_ports: self.ports.pop() - msg = f"Too many ports in router observation. Truncating." + msg = "Too many ports in router observation. Truncating." _LOGGER.warning(msg) self.default_observation = { - "PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)}, - "ACL": self.acl.default_observation - } + "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, + "ACL": self.acl.default_observation, + } def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status of ports and ACL configuration of the router. 
+ :rtype: Any + """ router_state = access_from_nested_dict(state, self.where) if router_state is NOT_PRESENT_IN_STATE: return self.default_observation obs = {} - obs["PORTS"] = {i+1:p.observe(state) for i,p in enumerate(self.ports)} + obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} obs["ACL"] = self.acl.observe(state) return obs @property def space(self) -> spaces.Space: - return spaces.Dict({ - "PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}), - "ACL": self.acl.space - }) + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for router status. + :rtype: spaces.Space + """ + return spaces.Dict( + {"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space} + ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation: + """ + Create a router observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the router observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this router's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed router observation instance. 
+ :rtype: RouterObservation + """ where = parent_where + ["nodes", config.hostname] if config.acl is None: @@ -575,156 +988,330 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): config.acl.protocol_list = config.protocol_list if config.ports is None: - config.ports = [PortObservation.ConfigSchema(port_id=i+1) for i in range(config.num_ports)] + config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)] ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) + class FirewallObservation(AbstractObservation, identifier="FIREWALL"): + """Firewall observation, provides status information about a firewall within the simulation environment.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FirewallObservation.""" + hostname: str + """Hostname of the firewall node, used for querying simulation state dictionary.""" ip_list: Optional[List[str]] = None + """List of IP addresses for encoding ACLs.""" wildcard_list: Optional[List[str]] = None + """List of IP wildcards for encoding ACLs.""" port_list: Optional[List[int]] = None + """List of ports for encoding ACLs.""" protocol_list: Optional[List[str]] = None + """List of protocols for encoding ACLs.""" num_rules: Optional[int] = None + """Number of ACL rules to show.""" + def __init__( + self, + where: WhereType, + ip_list: List[str], + wildcard_list: List[str], + port_list: List[int], + protocol_list: List[str], + num_rules: int, + ) -> None: + """ + Initialize a firewall observation instance. 
- def __init__(self, - where: WhereType, - ip_list: List[str], - wildcard_list: List[str], - port_list: List[int], - protocol_list: List[str], - num_rules: int, - )->None: + :param where: Where in the simulation state dictionary to find the relevant information for this firewall. + A typical location for a firewall might be ['network', 'nodes', ]. + :type where: WhereType + :param ip_list: List of IP addresses. + :type ip_list: List[str] + :param wildcard_list: List of wildcard rules. + :type wildcard_list: List[str] + :param port_list: List of port numbers. + :type port_list: List[int] + :param protocol_list: List of protocol types. + :type protocol_list: List[str] + :param num_rules: Number of rules configured in the firewall. + :type num_rules: int + """ self.where: WhereType = where - self.ports: List[PortObservation] = [PortObservation(where=self.where+["port", port_num]) for port_num in (1,2,3) ] - #TODO: check what the port nums are for firewall. - - self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) - self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) - self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) - self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) - self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, 
protocol_list=protocol_list) - self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.ports: List[PortObservation] = [ + PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3) + ] + # TODO: check what the port nums are for firewall. + self.internal_inbound_acl = ACLObservation( + where=self.where + ["acl", "internal", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.internal_outbound_acl = ACLObservation( + where=self.where + ["acl", "internal", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.dmz_inbound_acl = ACLObservation( + where=self.where + ["acl", "dmz", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.dmz_outbound_acl = ACLObservation( + where=self.where + ["acl", "dmz", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.external_inbound_acl = ACLObservation( + where=self.where + ["acl", "external", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.external_outbound_acl = ACLObservation( + where=self.where + ["acl", "external", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) self.default_observation = { - "PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)}, + "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, "INTERNAL": { "INBOUND": 
self.internal_inbound_acl.default_observation, "OUTBOUND": self.internal_outbound_acl.default_observation, - }, + }, "DMZ": { "INBOUND": self.dmz_inbound_acl.default_observation, "OUTBOUND": self.dmz_outbound_acl.default_observation, - }, + }, "EXTERNAL": { "INBOUND": self.external_inbound_acl.default_observation, "OUTBOUND": self.external_outbound_acl.default_observation, - }, - } + }, + } def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. + :rtype: Any + """ obs = { - "PORTS": {i+1:p.observe(state) for i,p in enumerate(self.ports)}, + "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, "INTERNAL": { "INBOUND": self.internal_inbound_acl.observe(state), "OUTBOUND": self.internal_outbound_acl.observe(state), - }, + }, "DMZ": { "INBOUND": self.dmz_inbound_acl.observe(state), "OUTBOUND": self.dmz_outbound_acl.observe(state), - }, + }, "EXTERNAL": { "INBOUND": self.external_inbound_acl.observe(state), "OUTBOUND": self.external_outbound_acl.observe(state), - }, - } + }, + } return obs @property def space(self) -> spaces.Space: - space =spaces.Dict({ - "PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}), - "INTERNAL": spaces.Dict({ - "INBOUND": self.internal_inbound_acl.space, - "OUTBOUND": self.internal_outbound_acl.space, - }), - "DMZ": spaces.Dict({ - "INBOUND": self.dmz_inbound_acl.space, - "OUTBOUND": self.dmz_outbound_acl.space, - }), - "EXTERNAL": spaces.Dict({ - "INBOUND": self.external_inbound_acl.space, - "OUTBOUND": self.external_outbound_acl.space, - }), - }) + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for firewall status. 
+ :rtype: spaces.Space + """ + space = spaces.Dict( + { + "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), + "INTERNAL": spaces.Dict( + { + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + } + ), + "DMZ": spaces.Dict( + { + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + } + ), + "EXTERNAL": spaces.Dict( + { + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + } + ), + } + ) return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation: - where = parent_where+["nodes", config.hostname] - return cls(where=where, ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules) + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: + """ + Create a firewall observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the firewall observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this firewall's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed firewall observation instance. 
+ :rtype: FirewallObservation + """ + where = parent_where + ["nodes", config.hostname] + return cls( + where=where, + ip_list=config.ip_list, + wildcard_list=config.wildcard_list, + port_list=config.port_list, + protocol_list=config.protocol_list, + num_rules=config.num_rules, + ) + class NodesObservation(AbstractObservation, identifier="NODES"): + """Nodes observation, provides status information about nodes within the simulation environment.""" + class ConfigSchema(AbstractObservation.ConfigSchema): - """Config""" + """Configuration schema for NodesObservation.""" + hosts: List[HostObservation.ConfigSchema] = [] + """List of configurations for host observations.""" routers: List[RouterObservation.ConfigSchema] = [] + """List of configurations for router observations.""" firewalls: List[FirewallObservation.ConfigSchema] = [] - + """List of configurations for firewall observations.""" num_services: int + """Number of services.""" num_applications: int + """Number of applications.""" num_folders: int + """Number of folders.""" num_files: int + """Number of files.""" num_nics: int + """Number of network interface cards (NICs).""" include_nmne: bool + """Flag to include nmne.""" include_num_access: bool - + """Flag to include the number of accesses.""" num_ports: int + """Number of ports.""" ip_list: List[str] + """List of IP addresses for encoding ACLs.""" wildcard_list: List[str] + """List of IP wildcards for encoding ACLs.""" port_list: List[int] + """List of ports for encoding ACLs.""" protocol_list: List[str] + """List of protocols for encoding ACLs.""" num_rules: int + """Number of ACL rules to show.""" + def __init__( + self, + where: WhereType, + hosts: List[HostObservation], + routers: List[RouterObservation], + firewalls: List[FirewallObservation], + ) -> None: + """ + Initialize a nodes observation instance. 
- def __init__(self, where: WhereType, hosts:List[HostObservation], routers:List[RouterObservation], firewalls:List[FirewallObservation])->None: - self.where :WhereType = where + :param where: Where in the simulation state dictionary to find the relevant information for nodes. + A typical location for nodes might be ['network', 'nodes']. + :type where: WhereType + :param hosts: List of host observations. + :type hosts: List[HostObservation] + :param routers: List of router observations. + :type routers: List[RouterObservation] + :param firewalls: List of firewall observations. + :type firewalls: List[FirewallObservation] + """ + self.where: WhereType = where self.hosts: List[HostObservation] = hosts self.routers: List[RouterObservation] = routers self.firewalls: List[FirewallObservation] = firewalls self.default_observation = { - **{f"HOST{i}":host.default_observation for i,host in enumerate(self.hosts)}, - **{f"ROUTER{i}":router.default_observation for i,router in enumerate(self.routers)}, - **{f"FIREWALL{i}":firewall.default_observation for i,firewall in enumerate(self.firewalls)}, + **{f"HOST{i}": host.default_observation for i, host in enumerate(self.hosts)}, + **{f"ROUTER{i}": router.default_observation for i, router in enumerate(self.routers)}, + **{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)}, } def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing status information about nodes. 
+ :rtype: Any + """ obs = { - **{f"HOST{i}":host.observe(state) for i,host in enumerate(self.hosts)}, - **{f"ROUTER{i}":router.observe(state) for i,router in enumerate(self.routers)}, - **{f"FIREWALL{i}":firewall.observe(state) for i,firewall in enumerate(self.firewalls)}, + **{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)}, + **{f"ROUTER{i}": router.observe(state) for i, router in enumerate(self.routers)}, + **{f"FIREWALL{i}": firewall.observe(state) for i, firewall in enumerate(self.firewalls)}, } return obs @property def space(self) -> spaces.Space: - space = spaces.Dict({ - **{f"HOST{i}":host.space for i,host in enumerate(self.hosts)}, - **{f"ROUTER{i}":router.space for i,router in enumerate(self.routers)}, - **{f"FIREWALL{i}":firewall.space for i,firewall in enumerate(self.firewalls)}, - }) + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for nodes. + :rtype: spaces.Space + """ + space = spaces.Dict( + { + **{f"HOST{i}": host.space for i, host in enumerate(self.hosts)}, + **{f"ROUTER{i}": router.space for i, router in enumerate(self.routers)}, + **{f"FIREWALL{i}": firewall.space for i, firewall in enumerate(self.firewalls)}, + } + ) return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: + """ + Create a nodes observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for nodes observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about nodes. + A typical location for nodes might be ['network', 'nodes']. + :type parent_where: WhereType, optional + :return: Constructed nodes observation instance. 
+ :rtype: NodesObservation + """ if parent_where is None: where = ["network", "nodes"] else: From 22e1dfea2f4d92a812378e794c3cadd9c926cb50 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Fri, 29 Mar 2024 14:14:03 +0000 Subject: [PATCH 08/16] #2417 Move classes to correct files --- .../agent/observations/acl_observation.py | 187 +++ .../agent/observations/agent_observations.py | 138 -- .../observations/file_system_observations.py | 207 +-- .../observations/firewall_observation.py | 213 +++ .../agent/observations/host_observations.py | 229 ++++ .../agent/observations/nic_observations.py | 273 ++-- .../agent/observations/node_observations.py | 1197 +---------------- .../game/agent/observations/observations.py | 467 +++---- .../agent/observations/router_observation.py | 142 ++ .../observations/software_observation.py | 192 ++- 10 files changed, 1332 insertions(+), 1913 deletions(-) create mode 100644 src/primaite/game/agent/observations/acl_observation.py delete mode 100644 src/primaite/game/agent/observations/agent_observations.py create mode 100644 src/primaite/game/agent/observations/firewall_observation.py create mode 100644 src/primaite/game/agent/observations/host_observations.py create mode 100644 src/primaite/game/agent/observations/router_observation.py diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py new file mode 100644 index 00000000..2d29223d --- /dev/null +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from ipaddress import IPv4Address +from typing import Dict, List, Optional + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + + +class 
ACLObservation(AbstractObservation, identifier="ACL"): + """ACL observation, provides information about access control lists within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ACLObservation.""" + + ip_list: Optional[List[IPv4Address]] = None + """List of IP addresses.""" + wildcard_list: Optional[List[str]] = None + """List of wildcard strings.""" + port_list: Optional[List[int]] = None + """List of port numbers.""" + protocol_list: Optional[List[str]] = None + """List of protocol names.""" + num_rules: Optional[int] = None + """Number of ACL rules.""" + + def __init__( + self, + where: WhereType, + num_rules: int, + ip_list: List[IPv4Address], + wildcard_list: List[str], + port_list: List[int], + protocol_list: List[str], + ) -> None: + """ + Initialize an ACL observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this ACL. + :type where: WhereType + :param num_rules: Number of ACL rules. + :type num_rules: int + :param ip_list: List of IP addresses. + :type ip_list: List[IPv4Address] + :param wildcard_list: List of wildcard strings. + :type wildcard_list: List[str] + :param port_list: List of port numbers. + :type port_list: List[int] + :param protocol_list: List of protocol names. 
+ :type protocol_list: List[str] + """ + self.where = where + self.num_rules: int = num_rules + self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)} + self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)} + self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)} + self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)} + self.default_observation: Dict = { + i + + 1: { + "position": i, + "permission": 0, + "source_ip_id": 0, + "source_wildcard_id": 0, + "source_port_id": 0, + "dest_ip_id": 0, + "dest_wildcard_id": 0, + "dest_port_id": 0, + "protocol_id": 0, + } + for i in range(self.num_rules) + } + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing ACL rules. + :rtype: ObsType + """ + acl_state: Dict = access_from_nested_dict(state, self.where) + if acl_state is NOT_PRESENT_IN_STATE: + return self.default_observation + obs = {} + acl_items = dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. 
+ while i < self.num_rules + 1: + rule_state = acl_items[i] + if rule_state is None: + obs[i] = { + "position": i - 1, + "permission": 0, + "source_ip_id": 0, + "source_wildcard_id": 0, + "source_port_id": 0, + "dest_ip_id": 0, + "dest_wildcard_id": 0, + "dest_port_id": 0, + "protocol_id": 0, + } + else: + src_ip = rule_state["src_ip_address"] + src_node_id = self.ip_to_id.get(src_ip, 1) + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = self.ip_to_id.get(dst_ip, 1) + src_wildcard = rule_state["source_wildcard_id"] + src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1) + dst_wildcard = rule_state["dest_wildcard_id"] + dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1) + src_port = rule_state["source_port_id"] + src_port_id = self.port_to_id.get(src_port, 1) + dst_port = rule_state["dest_port_id"] + dst_port_id = self.port_to_id.get(dst_port, 1) + protocol = rule_state["protocol"] + protocol_id = self.protocol_to_id.get(protocol, 1) + obs[i] = { + "position": i - 1, + "permission": rule_state["action"], + "source_ip_id": src_node_id, + "source_wildcard_id": src_wildcard_id, + "source_port_id": src_port_id, + "dest_ip_id": dst_node_ip, + "dest_wildcard_id": dst_wildcard_id, + "dest_port_id": dst_port_id, + "protocol_id": protocol_id, + } + i += 1 + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for ACL rules. 
+ :rtype: spaces.Space + """ + return spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), + "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), + "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), + "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), + "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), + } + ) + for i in range(self.num_rules) + } + ) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: + """ + Create an ACL observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the ACL observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this ACL's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed ACL observation instance. 
+ :rtype: ACLObservation + """ + return cls( + where=parent_where + ["acl", "acl"], + num_rules=config.num_rules, + ip_list=config.ip_list, + wildcard_list=config.wildcard_list, + port_list=config.port_list, + protocol_list=config.protocol_list, + ) diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py deleted file mode 100644 index 10370660..00000000 --- a/src/primaite/game/agent/observations/agent_observations.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING - -from gymnasium import spaces - -from primaite.game.agent.observations.node_observations import NodeObservation -from primaite.game.agent.observations.observations import ( - AbstractObservation, - AclObservation, - ICSObservation, - LinkObservation, - NullObservation, -) - -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - - -class UC2BlueObservation(AbstractObservation): - """Container for all observations used by the blue agent in UC2. - - TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler - for the purpose of compiling several observation components. - """ - - def __init__( - self, - nodes: List[NodeObservation], - links: List[LinkObservation], - acl: AclObservation, - ics: ICSObservation, - where: Optional[List[str]] = None, - ) -> None: - """Initialise UC2 blue observation. - - :param nodes: List of node observations - :type nodes: List[NodeObservation] - :param links: List of link observations - :type links: List[LinkObservation] - :param acl: The Access Control List observation - :type acl: AclObservation - :param ics: The ICS observation - :type ics: ICSObservation - :param where: Where in the simulation state dict to find information. 
Not used in this particular observation - because it only compiles other observations and doesn't contribute any new information, defaults to None - :type where: Optional[List[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.nodes: List[NodeObservation] = nodes - self.links: List[LinkObservation] = links - self.acl: AclObservation = acl - self.ics: ICSObservation = ics - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, - "ACL": self.acl.default_observation, - "ICS": self.ics.default_observation, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} - obs["ACL"] = self.acl.observe(state) - obs["ICS"] = self.ics.observe(state) - - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), - "ACL": self.acl.space, - "ICS": self.ics.space, - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": - """Create UC2 blue observation from a config. - - :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, - links, ACL and ICS observations. 
- :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed UC2 blue observation - :rtype: UC2BlueObservation - """ - node_configs = config["nodes"] - - num_services_per_node = config["num_services_per_node"] - num_folders_per_node = config["num_folders_per_node"] - num_files_per_folder = config["num_files_per_folder"] - num_nics_per_node = config["num_nics_per_node"] - nodes = [ - NodeObservation.from_config( - config=n, - game=game, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - for n in node_configs - ] - - link_configs = config["links"] - links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] - - acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, game=game) - - ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, game=game) - new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) - return new - diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 277bc51f..a30bfc82 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -1,107 +1,130 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from __future__ import annotations + +from typing import Dict, Iterable, List, Optional from gymnasium import spaces +from gymnasium.core import ObsType from primaite import getLogger -from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = 
getLogger(__name__) -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +class FileObservation(AbstractObservation, identifier="FILE"): + """File observation, provides status information about a file within the simulation environment.""" -class FileObservation(AbstractObservation): - """Observation of a file on a node in the network.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FileObservation.""" - def __init__(self, where: Optional[Tuple[str]] = None) -> None: + file_name: str + """Name of the file, used for querying simulation state dictionary.""" + include_num_access: Optional[bool] = None + """Whether to include the number of accesses to the file in the observation.""" + + def __init__(self, where: WhereType, include_num_access: bool) -> None: """ - Initialise file observation. + Initialize a file observation instance. - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',,'files',] - :type where: Optional[List[str]] + :param where: Where in the simulation state dictionary to find the relevant information for this file. + A typical location for a file might be + ['network', 'nodes', , 'file_system', 'folder', , 'files', ]. + :type where: WhereType + :param include_num_access: Whether to include the number of accesses to the file in the observation. + :type include_num_access: bool """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.default_observation: spaces.Space = {"health_status": 0} - "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." 
+ self.where: WhereType = where + self.include_num_access: bool = include_num_access - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. + self.default_observation: ObsType = {"health_status": 0} + if self.include_num_access: + self.default_observation["num_access"] = 0 - :param state: Simulation state dictionary + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. :type state: Dict - :return: Observation - :rtype: Dict + :return: Observation containing the health status of the file and optionally the number of accesses. + :rtype: ObsType """ - if self.where is None: - return self.default_observation file_state = access_from_nested_dict(state, self.where) if file_state is NOT_PRESENT_IN_STATE: return self.default_observation - return {"health_status": file_state["visible_status"]} + obs = {"health_status": file_state["visible_status"]} + if self.include_num_access: + obs["num_access"] = file_state["num_access"] + # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") + return obs @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. + """ + Gymnasium space object describing the observation space shape. - :return: Gymnasium space + :return: Gymnasium space representing the observation space for file status. :rtype: spaces.Space """ - return spaces.Dict({"health_status": spaces.Discrete(6)}) + space = {"health_status": spaces.Discrete(6)} + if self.include_num_access: + space["num_access"] = spaces.Discrete(4) + return spaces.Dict(space) @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": - """Create file observation from a config. - - :param config: Dictionary containing the configuration for this file observation. 
- :type config: Dict - :param game: _description_ - :type game: PrimaiteGame - :param parent_where: _description_, defaults to None - :type parent_where: _type_, optional - :return: _description_ - :rtype: _type_ + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation: """ - return cls(where=parent_where + ["files", config["file_name"]]) + Create a file observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the file observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this file's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed file observation instance. + :rtype: FileObservation + """ + return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access) -class FolderObservation(AbstractObservation): - """Folder observation, including files inside of the folder.""" +class FolderObservation(AbstractObservation, identifier="FOLDER"): + """Folder observation, provides status information about a folder within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FolderObservation.""" + + folder_name: str + """Name of the folder, used for querying simulation state dictionary.""" + files: List[FileObservation.ConfigSchema] = [] + """List of file configurations within the folder.""" + num_files: Optional[int] = None + """Number of spaces for file observations in this folder.""" + include_num_access: Optional[bool] = None + """Whether files in this folder should include the number of accesses in their observation.""" def __init__( - self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 + self, where: WhereType, files: 
Iterable[FileObservation], num_files: int, include_num_access: bool ) -> None: - """Initialise folder Observation, including files inside the folder. + """ + Initialize a folder observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',] - :type where: Optional[List[str]] - :param max_files: As size of the space must remain static, define max files that can be in this folder - , defaults to 5 - :type max_files: int, optional - :param file_positions: Defines the positioning within the observation space of particular files. This ensures - that even if new files are created, the existing files will always occupy the same space in the observation - space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the - observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same - name, it will take the position defined in this dict. Defaults to {} - :type file_positions: Dict[int, str], optional + A typical location for a folder might be ['network', 'nodes', , 'folders', ]. + :type where: WhereType + :param files: List of file observation instances within the folder. + :type files: Iterable[FileObservation] + :param num_files: Number of files expected in the folder. + :type num_files: int + :param include_num_access: Whether to include the number of accesses to files in the observation. 
+ :type include_num_access: bool """ - super().__init__() - - self.where: Optional[Tuple[str]] = where + self.where: WhereType = where self.files: List[FileObservation] = files - while len(self.files) < num_files_per_folder: - self.files.append(FileObservation()) - while len(self.files) > num_files_per_folder: + while len(self.files) < num_files: + self.files.append(FileObservation(where=None, include_num_access=include_num_access)) + while len(self.files) > num_files: truncated_file = self.files.pop() msg = f"Too many files in folder observation. Truncating file {truncated_file}" _LOGGER.warning(msg) @@ -111,16 +134,15 @@ class FolderObservation(AbstractObservation): "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, } - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the health status of the folder and status of files within the folder. + :rtype: ObsType """ - if self.where is None: - return self.default_observation folder_state = access_from_nested_dict(state, self.where) if folder_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -136,9 +158,10 @@ class FolderObservation(AbstractObservation): @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. + """ + Gymnasium space object describing the observation space shape. - :return: Gymnasium space + :return: Gymnasium space representing the observation space for folder status. 
:rtype: spaces.Space """ return spaces.Dict( @@ -149,29 +172,23 @@ class FolderObservation(AbstractObservation): ) @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 - ) -> "FolderObservation": - """Create folder observation from a config. Also creates child file observations. + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation: + """ + Create a folder observation from a configuration schema. - :param config: Dictionary containing the configuration for this folder observation. Includes the name of the - folder and the files inside of it. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame + :param config: Configuration schema containing the necessary information for the folder observation. + :type config: ConfigSchema :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. A typical location for a node ``where`` can be: - ['network','nodes',,'file_system'] - :type parent_where: Optional[List[str]] - :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed folder observation + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed folder observation instance. 
:rtype: FolderObservation """ - where = parent_where + ["folders", config["folder_name"]] + where = parent_where + ["folders", config.folder_name] - file_configs = config["files"] - files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] + # pass down shared/common config items + for file_config in config.files: + file_config.include_num_access = config.include_num_access - return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) + files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] + return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py new file mode 100644 index 00000000..6397d473 --- /dev/null +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from typing import Dict, List, Optional + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.nic_observations import PortObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType + +_LOGGER = getLogger(__name__) + + +class FirewallObservation(AbstractObservation, identifier="FIREWALL"): + """Firewall observation, provides status information about a firewall within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FirewallObservation.""" + + hostname: str + """Hostname of the firewall node, used for querying simulation state dictionary.""" + ip_list: Optional[List[str]] = None + """List of IP addresses for encoding ACLs.""" + wildcard_list: Optional[List[str]] = None + """List of IP 
wildcards for encoding ACLs.""" + port_list: Optional[List[int]] = None + """List of ports for encoding ACLs.""" + protocol_list: Optional[List[str]] = None + """List of protocols for encoding ACLs.""" + num_rules: Optional[int] = None + """Number of ACL rules to show.""" + + def __init__( + self, + where: WhereType, + ip_list: List[str], + wildcard_list: List[str], + port_list: List[int], + protocol_list: List[str], + num_rules: int, + ) -> None: + """ + Initialize a firewall observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this firewall. + A typical location for a firewall might be ['network', 'nodes', <hostname>]. + :type where: WhereType + :param ip_list: List of IP addresses. + :type ip_list: List[str] + :param wildcard_list: List of wildcard rules. + :type wildcard_list: List[str] + :param port_list: List of port numbers. + :type port_list: List[int] + :param protocol_list: List of protocol types. + :type protocol_list: List[str] + :param num_rules: Number of rules configured in the firewall. + :type num_rules: int + """ + self.where: WhereType = where + + self.ports: List[PortObservation] = [ + PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3) + ] + # TODO: check what the port nums are for firewall. 
+ + self.internal_inbound_acl = ACLObservation( + where=self.where + ["acl", "internal", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.internal_outbound_acl = ACLObservation( + where=self.where + ["acl", "internal", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.dmz_inbound_acl = ACLObservation( + where=self.where + ["acl", "dmz", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.dmz_outbound_acl = ACLObservation( + where=self.where + ["acl", "dmz", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.external_inbound_acl = ACLObservation( + where=self.where + ["acl", "external", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.external_outbound_acl = ACLObservation( + where=self.where + ["acl", "external", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + + self.default_observation = { + "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.default_observation, + "OUTBOUND": self.internal_outbound_acl.default_observation, + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.default_observation, + "OUTBOUND": self.dmz_outbound_acl.default_observation, + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.default_observation, + "OUTBOUND": self.external_outbound_acl.default_observation, + }, + } + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state 
of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. + :rtype: ObsType + """ + obs = { + "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, + } + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for firewall status. + :rtype: spaces.Space + """ + space = spaces.Dict( + { + "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), + "INTERNAL": spaces.Dict( + { + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + } + ), + "DMZ": spaces.Dict( + { + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + } + ), + "EXTERNAL": spaces.Dict( + { + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + } + ), + } + ) + return space + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: + """ + Create a firewall observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the firewall observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this firewall's + parent node. A typical location for a node might be ['network', 'nodes', ]. 
+ :type parent_where: WhereType, optional + :return: Constructed firewall observation instance. + :rtype: FirewallObservation + """ + where = parent_where + ["nodes", config.hostname] + return cls( + where=where, + ip_list=config.ip_list, + wildcard_list=config.wildcard_list, + port_list=config.port_list, + protocol_list=config.protocol_list, + num_rules=config.num_rules, + ) diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py new file mode 100644 index 00000000..34c9b3ff --- /dev/null +++ b/src/primaite/game/agent/observations/host_observations.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +from typing import Dict, List, Optional + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.file_system_observations import FolderObservation +from primaite.game.agent.observations.nic_observations import NICObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + + +class HostObservation(AbstractObservation, identifier="HOST"): + """Host observation, provides status information about a host within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for HostObservation.""" + + hostname: str + """Hostname of the host, used for querying simulation state dictionary.""" + services: List[ServiceObservation.ConfigSchema] = [] + """List of services to observe on the host.""" + applications: List[ApplicationObservation.ConfigSchema] = [] + """List of applications to observe on the host.""" + folders: List[FolderObservation.ConfigSchema] = [] + """List of folders to observe on 
the host.""" + network_interfaces: List[NICObservation.ConfigSchema] = [] + """List of network interfaces to observe on the host.""" + num_services: Optional[int] = None + """Number of spaces for service observations on this host.""" + num_applications: Optional[int] = None + """Number of spaces for application observations on this host.""" + num_folders: Optional[int] = None + """Number of spaces for folder observations on this host.""" + num_files: Optional[int] = None + """Number of spaces for file observations on this host.""" + num_nics: Optional[int] = None + """Number of spaces for network interface observations on this host.""" + include_nmne: Optional[bool] = None + """Whether network interface observations should include number of malicious network events.""" + include_num_access: Optional[bool] = None + """Whether to include the number of accesses to files observations on this host.""" + + def __init__( + self, + where: WhereType, + services: List[ServiceObservation], + applications: List[ApplicationObservation], + folders: List[FolderObservation], + network_interfaces: List[NICObservation], + num_services: int, + num_applications: int, + num_folders: int, + num_files: int, + num_nics: int, + include_nmne: bool, + include_num_access: bool, + ) -> None: + """ + Initialize a host observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this host. + A typical location for a host might be ['network', 'nodes', ]. + :type where: WhereType + :param services: List of service observations on the host. + :type services: List[ServiceObservation] + :param applications: List of application observations on the host. + :type applications: List[ApplicationObservation] + :param folders: List of folder observations on the host. + :type folders: List[FolderObservation] + :param network_interfaces: List of network interface observations on the host. 
+ :type network_interfaces: List[NICObservation] + :param num_services: Number of services to observe. + :type num_services: int + :param num_applications: Number of applications to observe. + :type num_applications: int + :param num_folders: Number of folders to observe. + :type num_folders: int + :param num_files: Number of files. + :type num_files: int + :param num_nics: Number of network interfaces. + :type num_nics: int + :param include_nmne: Flag to include the number of malicious network events (NMNE). + :type include_nmne: bool + :param include_num_access: Flag to include the number of accesses to files. + :type include_num_access: bool + """ + self.where: WhereType = where + + # Ensure lists have lengths equal to specified counts by truncating or padding + self.services: List[ServiceObservation] = services + while len(self.services) < num_services: + self.services.append(ServiceObservation(where=None)) + while len(self.services) > num_services: + truncated_service = self.services.pop() + msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" + _LOGGER.warning(msg) + + self.applications: List[ApplicationObservation] = applications + while len(self.applications) < num_applications: + self.applications.append(ApplicationObservation(where=None)) + while len(self.applications) > num_applications: + truncated_application = self.applications.pop() + msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" + _LOGGER.warning(msg) + + self.folders: List[FolderObservation] = folders + while len(self.folders) < num_folders: + self.folders.append( + FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access) + ) + while len(self.folders) > num_folders: + truncated_folder = self.folders.pop() + msg = f"Too many folders in Node observation space for node. 
Truncating folder {truncated_folder.where}" + _LOGGER.warning(msg) + + self.network_interfaces: List[NICObservation] = network_interfaces + while len(self.network_interfaces) < num_nics: + self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne)) + while len(self.network_interfaces) > num_nics: + truncated_nic = self.network_interfaces.pop() + msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" + _LOGGER.warning(msg) + + self.default_observation: ObsType = { + "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)}, + "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, + "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "operating_status": 0, + "num_file_creations": 0, + "num_file_deletions": 0, + } + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status information about the host. 
+ :rtype: ObsType + """ + node_state = access_from_nested_dict(state, self.where) + if node_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + obs["operating_status"] = node_state["operating_state"] + obs["NICS"] = { + i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) + } + obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] + obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for host status. + :rtype: spaces.Space + """ + shape = { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(5), + "NICS": spaces.Dict( + {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} + ), + "num_file_creations": spaces.Discrete(4), + "num_file_deletions": spaces.Discrete(4), + } + return spaces.Dict(shape) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation: + """ + Create a host observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the host observation. 
+ :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this host. + A typical location might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed host observation instance. + :rtype: HostObservation + """ + if parent_where is None: + where = ["network", "nodes", config.hostname] + else: + where = parent_where + ["nodes", config.hostname] + + # Pass down shared/common config items + for folder_config in config.folders: + folder_config.include_num_access = config.include_num_access + folder_config.num_files = config.num_files + for nic_config in config.network_interfaces: + nic_config.include_nmne = config.include_nmne + + services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] + applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] + folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] + nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] + + return cls( + where=where, + services=services, + applications=applications, + folders=folders, + network_interfaces=nics, + num_services=config.num_services, + num_applications=config.num_applications, + num_folders=config.num_folders, + num_files=config.num_files, + num_nics=config.num_nics, + include_nmne=config.include_nmne, + include_num_access=config.include_num_access, + ) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index de83e03a..3be53112 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,188 +1,157 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from __future__ import annotations + +from typing import Dict, Optional from gymnasium 
import spaces +from gymnasium.core import ObsType -from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -from primaite.simulator.network.nmne import CAPTURE_NMNE - -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame -class NicObservation(AbstractObservation): - """Observation of a Network Interface Card (NIC) in the network.""" +class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): + """Status information about a network interface within the simulation environment.""" - low_nmne_threshold: int = 0 - """The minimum number of malicious network events to be considered low.""" - med_nmne_threshold: int = 5 - """The minimum number of malicious network events to be considered medium.""" - high_nmne_threshold: int = 10 - """The minimum number of malicious network events to be considered high.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for NICObservation.""" - global CAPTURE_NMNE + nic_num: int + """Number of the network interface.""" + include_nmne: Optional[bool] = None + """Whether to include number of malicious network events (NMNE) in the observation.""" - @property - def default_observation(self) -> Dict: - """The default NIC observation dict.""" - data = {"nic_status": 0} - if CAPTURE_NMNE: - data.update({"NMNE": {"inbound": 0, "outbound": 0}}) - - return data - - def __init__( - self, - where: Optional[Tuple[str]] = None, - low_nmne_threshold: Optional[int] = 0, - med_nmne_threshold: Optional[int] = 5, - high_nmne_threshold: Optional[int] = 10, - ) -> None: - """Initialise NIC observation. - - :param where: Where in the simulation state dictionary to find the relevant information for this NIC. 
A typical - example may look like this: - ['network','nodes',,'NICs',] - If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. - :type where: Optional[Tuple[str]], optional + def __init__(self, where: WhereType, include_nmne: bool) -> None: """ - super().__init__() - self.where: Optional[Tuple[str]] = where + Initialize a network interface observation instance. - global CAPTURE_NMNE - if CAPTURE_NMNE: - self.nmne_inbound_last_step: int = 0 - """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets - us find the difference.""" - self.nmne_outbound_last_step: int = 0 - """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets - us find the difference.""" - - if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold: - self._validate_nmne_categories( - low_nmne_threshold=low_nmne_threshold, - med_nmne_threshold=med_nmne_threshold, - high_nmne_threshold=high_nmne_threshold, - ) - - def _validate_nmne_categories( - self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10 - ): + :param where: Where in the simulation state dictionary to find the relevant information for this interface. + A typical location for a network interface might be + ['network', 'nodes', , 'NICs', ]. + :type where: WhereType + :param include_nmne: Flag to determine whether to include NMNE information in the observation. + :type include_nmne: bool """ - Validates the nmne threshold config. + self.where = where + self.include_nmne: bool = include_nmne - If the configuration is valid, the thresholds will be set, otherwise, an exception is raised. 
+ self.default_observation: ObsType = {"nic_status": 0} + if self.include_nmne: + self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}}) - :param: low_nmne_threshold: The minimum number of malicious network events to be considered low - :param: med_nmne_threshold: The minimum number of malicious network events to be considered medium - :param: high_nmne_threshold: The minimum number of malicious network events to be considered high + def observe(self, state: Dict) -> ObsType: """ - if high_nmne_threshold <= med_nmne_threshold: - raise Exception( - f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater " - f"than medium nmne count ({med_nmne_threshold})" - ) + Generate observation based on the current state of the simulation. - if med_nmne_threshold <= low_nmne_threshold: - raise Exception( - f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater " - f"than low nmne count ({low_nmne_threshold})" - ) - - self.high_nmne_threshold = high_nmne_threshold - self.med_nmne_threshold = med_nmne_threshold - self.low_nmne_threshold = low_nmne_threshold - - def _categorise_mne_count(self, nmne_count: int) -> int: - """ - Categorise the number of Malicious Network Events (NMNEs) into discrete bins. - - This helps in classifying the severity or volume of MNEs into manageable levels for the agent. - - Bins are defined as follows: - - 0: No MNEs detected (0 events). - - 1: Low number of MNEs (default 1-5 events). - - 2: Moderate number of MNEs (default 6-10 events). - - 3: High number of MNEs (default more than 10 events). - - :param nmne_count: Number of MNEs detected. - :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. 
- """ - if nmne_count > self.high_nmne_threshold: - return 3 - elif nmne_count > self.med_nmne_threshold: - return 2 - elif nmne_count > self.low_nmne_threshold: - return 1 - return 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary + :param state: Simulation state dictionary. :type state: Dict - :return: Observation - :rtype: Dict + :return: Observation containing the status of the network interface and optionally NMNE information. + :rtype: ObsType """ - if self.where is None: - return self.default_observation nic_state = access_from_nested_dict(state, self.where) if nic_state is NOT_PRESENT_IN_STATE: return self.default_observation - else: - obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} - if CAPTURE_NMNE: - obs_dict.update({"NMNE": {}}) - direction_dict = nic_state["nmne"].get("direction", {}) - inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) - inbound_count = inbound_keywords.get("*", 0) - outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) - outbound_count = outbound_keywords.get("*", 0) - obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) - obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) - self.nmne_inbound_last_step = inbound_count - self.nmne_outbound_last_step = outbound_count - return obs_dict + + obs = {"nic_status": 1 if nic_state["enabled"] else 2} + if self.include_nmne: + obs.update({"NMNE": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - 
self.nmne_inbound_last_step) + obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + self.nmne_inbound_last_step = inbound_count + self.nmne_outbound_last_step = outbound_count + return obs @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for network interface status and NMNE information. + :rtype: spaces.Space + """ space = spaces.Dict({"nic_status": spaces.Discrete(3)}) - if CAPTURE_NMNE: + if self.include_nmne: space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) return space @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": - """Create NIC observation from a config. - - :param config: Dictionary containing the configuration for this NIC observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent - node. A typical location for a node ``where`` can be: ['network','nodes',] - :type parent_where: Optional[List[str]] - :return: Constructed NIC observation - :rtype: NicObservation + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: """ - low_nmne_threshold = None - med_nmne_threshold = None - high_nmne_threshold = None + Create a network interface observation from a configuration schema. - if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"): - threshold = game.options.thresholds["nmne"] + :param config: Configuration schema containing the necessary information for the network interface observation. 
+ :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this NIC's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed network interface observation instance. + :rtype: NICObservation + """ + return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne) - low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None - med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None - high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None - return cls( - where=parent_where + ["NICs", config["nic_num"]], - low_nmne_threshold=low_nmne_threshold, - med_nmne_threshold=med_nmne_threshold, - high_nmne_threshold=high_nmne_threshold, - ) +class PortObservation(AbstractObservation, identifier="PORT"): + """Port observation, provides status information about a network port within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for PortObservation.""" + + port_id: int + """Identifier of the port, used for querying simulation state dictionary.""" + + def __init__(self, where: WhereType) -> None: + """ + Initialize a port observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this port. + A typical location for a port might be ['network', 'nodes', , 'NICs', ]. + :type where: WhereType + """ + self.where = where + self.default_observation: ObsType = {"operating_status": 0} + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the operating status of the port. 
+ :rtype: ObsType + """ + port_state = access_from_nested_dict(state, self.where) + if port_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"operating_status": 1 if port_state["enabled"] else 2} + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for port status. + :rtype: spaces.Space + """ + return spaces.Dict({"operating_status": spaces.Discrete(3)}) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation: + """ + Create a port observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the port observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this port's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed port observation instance. 
+ :rtype: PortObservation + """ + return cls(where=parent_where + ["NICs", config.port_id]) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index c702f8e2..0e63f440 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,1199 +1,18 @@ from __future__ import annotations -from ipaddress import IPv4Address -from typing import Any, Dict, Iterable, List, Optional +from typing import Dict, List from gymnasium import spaces from gymnasium.core import ObsType from primaite import getLogger -from primaite.game.agent.observations.observations import AbstractObservation -from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.game.agent.observations.firewall_observation import FirewallObservation +from primaite.game.agent.observations.host_observations import HostObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.observations.router_observation import RouterObservation _LOGGER = getLogger(__name__) -WhereType = Iterable[str | int] | None - - -class ServiceObservation(AbstractObservation, identifier="SERVICE"): - """Service observation, shows status of a service in the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for ServiceObservation.""" - - service_name: str - """Name of the service, used for querying simulation state dictionary""" - - def __init__(self, where: WhereType) -> None: - """ - Initialize a service observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this service. - A typical location for a service might be ['network', 'nodes', , 'services', ]. 
- :type where: WhereType - """ - self.where = where - self.default_observation = {"operating_status": 0, "health_status": 0} - - def observe(self, state: Dict) -> ObsType: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the operating status and health status of the service. - :rtype: Any - """ - service_state = access_from_nested_dict(state, self.where) - if service_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], - } - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for service status. - :rtype: spaces.Space - """ - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: - """ - Create a service observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the service observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this service's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed service observation instance. 
- :rtype: ServiceObservation - """ - return cls(where=parent_where + ["services", config.service_name]) - - -class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): - """Application observation, shows the status of an application within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for ApplicationObservation.""" - - application_name: str - """Name of the application, used for querying simulation state dictionary""" - - def __init__(self, where: WhereType) -> None: - """ - Initialise an application observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this application. - A typical location for an application might be - ['network', 'nodes', , 'applications', ]. - :type where: WhereType - """ - self.where = where - self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Obs containing the operating status, health status, and number of executions of the application. - :rtype: Any - """ - application_state = access_from_nested_dict(state, self.where) - if application_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": application_state["operating_state"], - "health_status": application_state["health_state_visible"], - "num_executions": application_state["num_executions"], - } - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for application status. 
- :rtype: spaces.Space - """ - return spaces.Dict( - { - "operating_status": spaces.Discrete(7), - "health_status": spaces.Discrete(5), - "num_executions": spaces.Discrete(4), - } - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation: - """ - Create an application observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the application observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this application's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed application observation instance. - :rtype: ApplicationObservation - """ - return cls(where=parent_where + ["applications", config.application_name]) - - -class FileObservation(AbstractObservation, identifier="FILE"): - """File observation, provides status information about a file within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for FileObservation.""" - - file_name: str - """Name of the file, used for querying simulation state dictionary.""" - include_num_access: Optional[bool] = None - """Whether to include the number of accesses to the file in the observation.""" - - def __init__(self, where: WhereType, include_num_access: bool) -> None: - """ - Initialize a file observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this file. - A typical location for a file might be - ['network', 'nodes', , 'file_system', 'folder', , 'files', ]. - :type where: WhereType - :param include_num_access: Whether to include the number of accesses to the file in the observation. 
- :type include_num_access: bool - """ - self.where: WhereType = where - self.include_num_access: bool = include_num_access - - self.default_observation: ObsType = {"health_status": 0} - if self.include_num_access: - self.default_observation["num_access"] = 0 - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the health status of the file and optionally the number of accesses. - :rtype: Any - """ - file_state = access_from_nested_dict(state, self.where) - if file_state is NOT_PRESENT_IN_STATE: - return self.default_observation - obs = {"health_status": file_state["visible_status"]} - if self.include_num_access: - obs["num_access"] = file_state["num_access"] - # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for file status. - :rtype: spaces.Space - """ - space = {"health_status": spaces.Discrete(6)} - if self.include_num_access: - space["num_access"] = spaces.Discrete(4) - return spaces.Dict(space) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation: - """ - Create a file observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the file observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this file's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed file observation instance. 
- :rtype: FileObservation - """ - return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access) - - -class FolderObservation(AbstractObservation, identifier="FOLDER"): - """Folder observation, provides status information about a folder within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for FolderObservation.""" - - folder_name: str - """Name of the folder, used for querying simulation state dictionary.""" - files: List[FileObservation.ConfigSchema] = [] - """List of file configurations within the folder.""" - num_files: Optional[int] = None - """Number of spaces for file observations in this folder.""" - include_num_access: Optional[bool] = None - """Whether files in this folder should include the number of accesses in their observation.""" - - def __init__( - self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool - ) -> None: - """ - Initialize a folder observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a folder might be ['network', 'nodes', , 'folders', ]. - :type where: WhereType - :param files: List of file observation instances within the folder. - :type files: Iterable[FileObservation] - :param num_files: Number of files expected in the folder. - :type num_files: int - :param include_num_access: Whether to include the number of accesses to files in the observation. - :type include_num_access: bool - """ - self.where: WhereType = where - - self.files: List[FileObservation] = files - while len(self.files) < num_files: - self.files.append(FileObservation(where=None, include_num_access=include_num_access)) - while len(self.files) > num_files: - truncated_file = self.files.pop() - msg = f"Too many files in folder observation. 
Truncating file {truncated_file}" - _LOGGER.warning(msg) - - self.default_observation = { - "health_status": 0, - "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the health status of the folder and status of files within the folder. - :rtype: Any - """ - folder_state = access_from_nested_dict(state, self.where) - if folder_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - health_status = folder_state["health_status"] - - obs = {} - - obs["health_status"] = health_status - obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} - - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for folder status. - :rtype: spaces.Space - """ - return spaces.Dict( - { - "health_status": spaces.Discrete(6), - "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), - } - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation: - """ - Create a folder observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the folder observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed folder observation instance. 
- :rtype: FolderObservation - """ - where = parent_where + ["folders", config.folder_name] - - # pass down shared/common config items - for file_config in config.files: - file_config.include_num_access = config.include_num_access - - files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] - return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) - - -class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): - """Status information about a network interface within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for NICObservation.""" - - nic_num: int - """Number of the network interface.""" - include_nmne: Optional[bool] = None - """Whether to include number of malicious network events (NMNE) in the observation.""" - - def __init__(self, where: WhereType, include_nmne: bool) -> None: - """ - Initialize a network interface observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this interface. - A typical location for a network interface might be - ['network', 'nodes', , 'NICs', ]. - :type where: WhereType - :param include_nmne: Flag to determine whether to include NMNE information in the observation. - :type include_nmne: bool - """ - self.where = where - self.include_nmne: bool = include_nmne - - self.default_observation: ObsType = {"nic_status": 0} - if self.include_nmne: - self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}}) - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the status of the network interface and optionally NMNE information. 
- :rtype: Any - """ - nic_state = access_from_nested_dict(state, self.where) - - if nic_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {"nic_status": 1 if nic_state["enabled"] else 2} - if self.include_nmne: - obs.update({"NMNE": {}}) - direction_dict = nic_state["nmne"].get("direction", {}) - inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) - inbound_count = inbound_keywords.get("*", 0) - outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) - outbound_count = outbound_keywords.get("*", 0) - obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) - obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) - self.nmne_inbound_last_step = inbound_count - self.nmne_outbound_last_step = outbound_count - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for network interface status and NMNE information. - :rtype: spaces.Space - """ - space = spaces.Dict({"nic_status": spaces.Discrete(3)}) - - if self.include_nmne: - space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) - - return space - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: - """ - Create a network interface observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the network interface observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this NIC's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed network interface observation instance. 
- :rtype: NICObservation - """ - return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne) - - -class HostObservation(AbstractObservation, identifier="HOST"): - """Host observation, provides status information about a host within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for HostObservation.""" - - hostname: str - """Hostname of the host, used for querying simulation state dictionary.""" - services: List[ServiceObservation.ConfigSchema] = [] - """List of services to observe on the host.""" - applications: List[ApplicationObservation.ConfigSchema] = [] - """List of applications to observe on the host.""" - folders: List[FolderObservation.ConfigSchema] = [] - """List of folders to observe on the host.""" - network_interfaces: List[NICObservation.ConfigSchema] = [] - """List of network interfaces to observe on the host.""" - num_services: Optional[int] = None - """Number of spaces for service observations on this host.""" - num_applications: Optional[int] = None - """Number of spaces for application observations on this host.""" - num_folders: Optional[int] = None - """Number of spaces for folder observations on this host.""" - num_files: Optional[int] = None - """Number of spaces for file observations on this host.""" - num_nics: Optional[int] = None - """Number of spaces for network interface observations on this host.""" - include_nmne: Optional[bool] = None - """Whether network interface observations should include number of malicious network events.""" - include_num_access: Optional[bool] = None - """Whether to include the number of accesses to files observations on this host.""" - - def __init__( - self, - where: WhereType, - services: List[ServiceObservation], - applications: List[ApplicationObservation], - folders: List[FolderObservation], - network_interfaces: List[NICObservation], - num_services: int, - num_applications: int, - num_folders: int, - 
num_files: int, - num_nics: int, - include_nmne: bool, - include_num_access: bool, - ) -> None: - """ - Initialize a host observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this host. - A typical location for a host might be ['network', 'nodes', ]. - :type where: WhereType - :param services: List of service observations on the host. - :type services: List[ServiceObservation] - :param applications: List of application observations on the host. - :type applications: List[ApplicationObservation] - :param folders: List of folder observations on the host. - :type folders: List[FolderObservation] - :param network_interfaces: List of network interface observations on the host. - :type network_interfaces: List[NICObservation] - :param num_services: Number of services to observe. - :type num_services: int - :param num_applications: Number of applications to observe. - :type num_applications: int - :param num_folders: Number of folders to observe. - :type num_folders: int - :param num_files: Number of files. - :type num_files: int - :param num_nics: Number of network interfaces. - :type num_nics: int - :param include_nmne: Flag to include network metrics and errors. - :type include_nmne: bool - :param include_num_access: Flag to include the number of accesses to files. - :type include_num_access: bool - """ - self.where: WhereType = where - - # Ensure lists have lengths equal to specified counts by truncating or padding - self.services: List[ServiceObservation] = services - while len(self.services) < num_services: - self.services.append(ServiceObservation(where=None)) - while len(self.services) > num_services: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. 
Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - - self.applications: List[ApplicationObservation] = applications - while len(self.applications) < num_applications: - self.applications.append(ApplicationObservation(where=None)) - while len(self.applications) > num_applications: - truncated_application = self.applications.pop() - msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" - _LOGGER.warning(msg) - - self.folders: List[FolderObservation] = folders - while len(self.folders) < num_folders: - self.folders.append( - FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access) - ) - while len(self.folders) > num_folders: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NICObservation] = network_interfaces - while len(self.network_interfaces) < num_nics: - self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne)) - while len(self.network_interfaces) > num_nics: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" - _LOGGER.warning(msg) - - self.default_observation: ObsType = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, - "num_file_creations": 0, - "num_file_deletions": 0, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. 
- :type state: Dict - :return: Observation containing the status information about the host. - :rtype: Any - """ - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NICS"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] - obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for host status. - :rtype: spaces.Space - """ - shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NICS": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - "num_file_creations": spaces.Discrete(4), - "num_file_deletions": spaces.Discrete(4), - } - return spaces.Dict(shape) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation: - """ - Create a host observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the host observation. 
- :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this host. - A typical location might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed host observation instance. - :rtype: HostObservation - """ - if parent_where is None: - where = ["network", "nodes", config.hostname] - else: - where = parent_where + ["nodes", config.hostname] - - # Pass down shared/common config items - for folder_config in config.folders: - folder_config.include_num_access = config.include_num_access - folder_config.num_files = config.num_files - for nic_config in config.network_interfaces: - nic_config.include_nmne = config.include_nmne - - services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] - applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] - folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] - nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] - - return cls( - where=where, - services=services, - applications=applications, - folders=folders, - network_interfaces=nics, - num_services=config.num_services, - num_applications=config.num_applications, - num_folders=config.num_folders, - num_files=config.num_files, - num_nics=config.num_nics, - include_nmne=config.include_nmne, - include_num_access=config.include_num_access, - ) - - -class PortObservation(AbstractObservation, identifier="PORT"): - """Port observation, provides status information about a network port within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for PortObservation.""" - - port_id: int - """Identifier of the port, used for querying simulation state dictionary.""" - - def __init__(self, where: WhereType) -> None: - """ - Initialize a port 
observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this port. - A typical location for a port might be ['network', 'nodes', , 'NICs', ]. - :type where: WhereType - """ - self.where = where - self.default_observation: ObsType = {"operating_status": 0} - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the operating status of the port. - :rtype: Any - """ - port_state = access_from_nested_dict(state, self.where) - if port_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return {"operating_status": 1 if port_state["enabled"] else 2} - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for port status. - :rtype: spaces.Space - """ - return spaces.Dict({"operating_status": spaces.Discrete(3)}) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation: - """ - Create a port observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the port observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this port's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed port observation instance. 
- :rtype: PortObservation - """ - return cls(where=parent_where + ["NICs", config.port_id]) - - -class ACLObservation(AbstractObservation, identifier="ACL"): - """ACL observation, provides information about access control lists within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for ACLObservation.""" - - ip_list: Optional[List[IPv4Address]] = None - """List of IP addresses.""" - wildcard_list: Optional[List[str]] = None - """List of wildcard strings.""" - port_list: Optional[List[int]] = None - """List of port numbers.""" - protocol_list: Optional[List[str]] = None - """List of protocol names.""" - num_rules: Optional[int] = None - """Number of ACL rules.""" - - def __init__( - self, - where: WhereType, - num_rules: int, - ip_list: List[IPv4Address], - wildcard_list: List[str], - port_list: List[int], - protocol_list: List[str], - ) -> None: - """ - Initialize an ACL observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this ACL. - :type where: WhereType - :param num_rules: Number of ACL rules. - :type num_rules: int - :param ip_list: List of IP addresses. - :type ip_list: List[IPv4Address] - :param wildcard_list: List of wildcard strings. - :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] - :param protocol_list: List of protocol names. 
- :type protocol_list: List[str] - """ - self.where = where - self.num_rules: int = num_rules - self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)} - self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)} - self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)} - self.default_observation: Dict = { - i - + 1: { - "position": i, - "permission": 0, - "source_ip_id": 0, - "source_wildcard_id": 0, - "source_port_id": 0, - "dest_ip_id": 0, - "dest_wildcard_id": 0, - "dest_port_id": 0, - "protocol_id": 0, - } - for i in range(self.num_rules) - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing ACL rules. - :rtype: Any - """ - acl_state: Dict = access_from_nested_dict(state, self.where) - if acl_state is NOT_PRESENT_IN_STATE: - return self.default_observation - obs = {} - acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. 
- while i < self.num_rules + 1: - rule_state = acl_items[i] - if rule_state is None: - obs[i] = { - "position": i - 1, - "permission": 0, - "source_ip_id": 0, - "source_wildcard_id": 0, - "source_port_id": 0, - "dest_ip_id": 0, - "dest_wildcard_id": 0, - "dest_port_id": 0, - "protocol_id": 0, - } - else: - src_ip = rule_state["src_ip_address"] - src_node_id = self.ip_to_id.get(src_ip, 1) - dst_ip = rule_state["dst_ip_address"] - dst_node_ip = self.ip_to_id.get(dst_ip, 1) - src_wildcard = rule_state["source_wildcard_id"] - src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1) - dst_wildcard = rule_state["dest_wildcard_id"] - dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1) - src_port = rule_state["source_port_id"] - src_port_id = self.port_to_id.get(src_port, 1) - dst_port = rule_state["dest_port_id"] - dst_port_id = self.port_to_id.get(dst_port, 1) - protocol = rule_state["protocol"] - protocol_id = self.protocol_to_id.get(protocol, 1) - obs[i] = { - "position": i - 1, - "permission": rule_state["action"], - "source_ip_id": src_node_id, - "source_wildcard_id": src_wildcard_id, - "source_port_id": src_port_id, - "dest_ip_id": dst_node_ip, - "dest_wildcard_id": dst_wildcard_id, - "dest_port_id": dst_port_id, - "protocol_id": protocol_id, - } - i += 1 - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for ACL rules. 
- :rtype: spaces.Space - """ - return spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), - "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), - "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), - "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), - "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), - "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), - "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) - } - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: - """ - Create an ACL observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the ACL observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this ACL's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed ACL observation instance. 
- :rtype: ACLObservation - """ - return cls( - where=parent_where + ["acl", "acl"], - num_rules=config.num_rules, - ip_list=config.ip_list, - wildcard_list=config.wildcard_list, - port_list=config.port_list, - protocol_list=config.protocol_list, - ) - - -class RouterObservation(AbstractObservation, identifier="ROUTER"): - """Router observation, provides status information about a router within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for RouterObservation.""" - - hostname: str - """Hostname of the router, used for querying simulation state dictionary.""" - ports: Optional[List[PortObservation.ConfigSchema]] = None - """Configuration of port observations for this router.""" - num_ports: Optional[int] = None - """Number of port observations configured for this router.""" - acl: Optional[ACLObservation.ConfigSchema] = None - """Configuration of ACL observation on this router.""" - ip_list: Optional[List[str]] = None - """List of IP addresses for encoding ACLs.""" - wildcard_list: Optional[List[str]] = None - """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None - """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None - """List of protocols for encoding ACLs.""" - num_rules: Optional[int] = None - """Number of rules ACL rules to show.""" - - def __init__( - self, - where: WhereType, - ports: List[PortObservation], - num_ports: int, - acl: ACLObservation, - ) -> None: - """ - Initialize a router observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this router. - A typical location for a router might be ['network', 'nodes', ]. - :type where: WhereType - :param ports: List of port observations representing the ports of the router. - :type ports: List[PortObservation] - :param num_ports: Number of ports for the router. 
- :type num_ports: int - :param acl: ACL observation representing the access control list of the router. - :type acl: ACLObservation - """ - self.where: WhereType = where - self.ports: List[PortObservation] = ports - self.acl: ACLObservation = acl - self.num_ports: int = num_ports - - while len(self.ports) < num_ports: - self.ports.append(PortObservation(where=None)) - while len(self.ports) > num_ports: - self.ports.pop() - msg = "Too many ports in router observation. Truncating." - _LOGGER.warning(msg) - - self.default_observation = { - "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, - "ACL": self.acl.default_observation, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the status of ports and ACL configuration of the router. - :rtype: Any - """ - router_state = access_from_nested_dict(state, self.where) - if router_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} - obs["ACL"] = self.acl.observe(state) - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for router status. - :rtype: spaces.Space - """ - return spaces.Dict( - {"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space} - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation: - """ - Create a router observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the router observation. 
- :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this router's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed router observation instance. - :rtype: RouterObservation - """ - where = parent_where + ["nodes", config.hostname] - - if config.acl is None: - config.acl = ACLObservation.ConfigSchema() - if config.acl.num_rules is None: - config.acl.num_rules = config.num_rules - if config.acl.ip_list is None: - config.acl.ip_list = config.ip_list - if config.acl.wildcard_list is None: - config.acl.wildcard_list = config.wildcard_list - if config.acl.port_list is None: - config.acl.port_list = config.port_list - if config.acl.protocol_list is None: - config.acl.protocol_list = config.protocol_list - - if config.ports is None: - config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)] - - ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] - acl = ACLObservation.from_config(config=config.acl, parent_where=where) - return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) - - -class FirewallObservation(AbstractObservation, identifier="FIREWALL"): - """Firewall observation, provides status information about a firewall within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for FirewallObservation.""" - - hostname: str - """Hostname of the firewall node, used for querying simulation state dictionary.""" - ip_list: Optional[List[str]] = None - """List of IP addresses for encoding ACLs.""" - wildcard_list: Optional[List[str]] = None - """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None - """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None - """List of protocols for encoding ACLs.""" - 
num_rules: Optional[int] = None - """Number of rules ACL rules to show.""" - - def __init__( - self, - where: WhereType, - ip_list: List[str], - wildcard_list: List[str], - port_list: List[int], - protocol_list: List[str], - num_rules: int, - ) -> None: - """ - Initialize a firewall observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this firewall. - A typical location for a firewall might be ['network', 'nodes', ]. - :type where: WhereType - :param ip_list: List of IP addresses. - :type ip_list: List[str] - :param wildcard_list: List of wildcard rules. - :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] - :param protocol_list: List of protocol types. - :type protocol_list: List[str] - :param num_rules: Number of rules configured in the firewall. - :type num_rules: int - """ - self.where: WhereType = where - - self.ports: List[PortObservation] = [ - PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3) - ] - # TODO: check what the port nums are for firewall. 
- - self.internal_inbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "inbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.internal_outbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "outbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.dmz_inbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "inbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.dmz_outbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "outbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.external_inbound_acl = ACLObservation( - where=self.where + ["acl", "external", "inbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.external_outbound_acl = ACLObservation( - where=self.where + ["acl", "external", "outbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - - self.default_observation = { - "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.default_observation, - "OUTBOUND": self.internal_outbound_acl.default_observation, - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.default_observation, - "OUTBOUND": self.dmz_outbound_acl.default_observation, - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.default_observation, - "OUTBOUND": self.external_outbound_acl.default_observation, - }, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of 
the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. - :rtype: Any - """ - obs = { - "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.observe(state), - "OUTBOUND": self.internal_outbound_acl.observe(state), - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.observe(state), - "OUTBOUND": self.dmz_outbound_acl.observe(state), - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.observe(state), - "OUTBOUND": self.external_outbound_acl.observe(state), - }, - } - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for firewall status. - :rtype: spaces.Space - """ - space = spaces.Dict( - { - "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), - "INTERNAL": spaces.Dict( - { - "INBOUND": self.internal_inbound_acl.space, - "OUTBOUND": self.internal_outbound_acl.space, - } - ), - "DMZ": spaces.Dict( - { - "INBOUND": self.dmz_inbound_acl.space, - "OUTBOUND": self.dmz_outbound_acl.space, - } - ), - "EXTERNAL": spaces.Dict( - { - "INBOUND": self.external_inbound_acl.space, - "OUTBOUND": self.external_outbound_acl.space, - } - ), - } - ) - return space - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: - """ - Create a firewall observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the firewall observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this firewall's - parent node. A typical location for a node might be ['network', 'nodes', ]. 
- :type parent_where: WhereType, optional - :return: Constructed firewall observation instance. - :rtype: FirewallObservation - """ - where = parent_where + ["nodes", config.hostname] - return cls( - where=where, - ip_list=config.ip_list, - wildcard_list=config.wildcard_list, - port_list=config.port_list, - protocol_list=config.protocol_list, - num_rules=config.num_rules, - ) - class NodesObservation(AbstractObservation, identifier="NODES"): """Nodes observation, provides status information about nodes within the simulation environment.""" @@ -1266,14 +85,14 @@ class NodesObservation(AbstractObservation, identifier="NODES"): **{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)}, } - def observe(self, state: Dict) -> Any: + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. :param state: Simulation state dictionary. :type state: Dict :return: Observation containing status information about nodes. - :rtype: Any + :rtype: ObsType """ obs = { **{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)}, @@ -1300,7 +119,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation: """ Create a nodes observation from a configuration schema. 
diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index dc41e8e5..08871072 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,24 +1,23 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type +from typing import Any, Dict, Iterable, Type from gymnasium import spaces from pydantic import BaseModel, ConfigDict from primaite import getLogger -from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +WhereType = Iterable[str | int] | None class AbstractObservation(ABC): """Abstract class for an observation space component.""" class ConfigSchema(ABC, BaseModel): + """Config schema for observations.""" + model_config = ConfigDict(extra="forbid") _registry: Dict[str, Type["AbstractObservation"]] = {} @@ -61,269 +60,271 @@ class AbstractObservation(ABC): @classmethod def from_config(cls, cfg: Dict) -> "AbstractObservation": """Create this observation space component form a serialised format.""" - ObservationType = cls._registry[cfg['type']] + ObservationType = cls._registry[cfg["type"]] return ObservationType.from_config(cfg=cfg) -# class LinkObservation(AbstractObservation): -# """Observation of a link in the network.""" +''' +class LinkObservation(AbstractObservation): + """Observation of a link in the network.""" -# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} -# "Default observation is what should be returned when the link doesn't exist." + default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} + "Default observation is what should be returned when the link doesn't exist." 
-# def __init__(self, where: Optional[Tuple[str]] = None) -> None: -# """Initialise link observation. + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise link observation. -# :param where: Store information about where in the simulation state dictionary to find the relevant information. -# Optional. If None, this corresponds that the file does not exist and the observation will be populated with -# zeroes. + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. -# A typical location for a service looks like this: -# `['network','nodes',,'servics', ]` -# :type where: Optional[List[str]] -# """ -# super().__init__() -# self.where: Optional[Tuple[str]] = where + A typical location for a service looks like this: + `['network','nodes',,'servics', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where -# def observe(self, state: Dict) -> Dict: -# """Generate observation based on the current state of the simulation. + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. 
-# :param state: Simulation state dictionary -# :type state: Dict -# :return: Observation -# :rtype: Dict -# """ -# if self.where is None: -# return self.default_observation + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation -# link_state = access_from_nested_dict(state, self.where) -# if link_state is NOT_PRESENT_IN_STATE: -# return self.default_observation + link_state = access_from_nested_dict(state, self.where) + if link_state is NOT_PRESENT_IN_STATE: + return self.default_observation -# bandwidth = link_state["bandwidth"] -# load = link_state["current_load"] -# if load == 0: -# utilisation_category = 0 -# else: -# utilisation_fraction = load / bandwidth -# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% -# utilisation_category = int(utilisation_fraction * 9) + 1 + bandwidth = link_state["bandwidth"] + load = link_state["current_load"] + if load == 0: + utilisation_category = 0 + else: + utilisation_fraction = load / bandwidth + # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% + utilisation_category = int(utilisation_fraction * 9) + 1 -# # TODO: once the links support separte load per protocol, this needs amendment to reflect that. -# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} + # TODO: once the links support separte load per protocol, this needs amendment to reflect that. + return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} -# @property -# def space(self) -> spaces.Space: -# """Gymnasium space object describing the observation space shape. + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. 
-# :return: Gymnasium space -# :rtype: spaces.Space -# """ -# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) -# @classmethod -# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": -# """Create link observation from a config. + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": + """Create link observation from a config. -# :param config: Dictionary containing the configuration for this link observation. -# :type config: Dict -# :param game: Reference to the PrimaiteGame object that spawned this observation. -# :type game: PrimaiteGame -# :return: Constructed link observation -# :rtype: LinkObservation -# """ -# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) + :param config: Dictionary containing the configuration for this link observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Constructed link observation + :rtype: LinkObservation + """ + return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -# class AclObservation(AbstractObservation): -# """Observation of an Access Control List (ACL) in the network.""" +class AclObservation(AbstractObservation): + """Observation of an Access Control List (ACL) in the network.""" -# # TODO: should where be optional, and we can use where=None to pad the observation space? -# # definitely the current approach does not support tracking files that aren't specified by name, for example -# # if a file is created at runtime, we have currently got no way of telling the observation space to track it. -# # this needs adding, but not for the MVP. 
-# def __init__( -# self, -# node_ip_to_id: Dict[str, int], -# ports: List[int], -# protocols: List[str], -# where: Optional[Tuple[str]] = None, -# num_rules: int = 10, -# ) -> None: -# """Initialise ACL observation. + # TODO: should where be optional, and we can use where=None to pad the observation space? + # definitely the current approach does not support tracking files that aren't specified by name, for example + # if a file is created at runtime, we have currently got no way of telling the observation space to track it. + # this needs adding, but not for the MVP. + def __init__( + self, + node_ip_to_id: Dict[str, int], + ports: List[int], + protocols: List[str], + where: Optional[Tuple[str]] = None, + num_rules: int = 10, + ) -> None: + """Initialise ACL observation. -# :param node_ip_to_id: Mapping between IP address and ID. -# :type node_ip_to_id: Dict[str, int] -# :param ports: List of ports which are part of the game that define the ordering when converting to an ID -# :type ports: List[int] -# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID -# :type protocols: list[str] -# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
A typical -# example may look like this: -# ['network','nodes',,'acl','acl'] -# :type where: Optional[Tuple[str]], optional -# :param num_rules: , defaults to 10 -# :type num_rules: int, optional -# """ -# super().__init__() -# self.where: Optional[Tuple[str]] = where -# self.num_rules: int = num_rules -# self.node_to_id: Dict[str, int] = node_ip_to_id -# "List of node IP addresses, order in this list determines how they are converted to an ID" -# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} -# "List of ports which are part of the game that define the ordering when converting to an ID" -# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} -# "List of protocols which are part of the game, defines ordering when converting to an ID" -# self.default_observation: Dict = { -# i -# + 1: { -# "position": i, -# "permission": 0, -# "source_node_id": 0, -# "source_port": 0, -# "dest_node_id": 0, -# "dest_port": 0, -# "protocol": 0, -# } -# for i in range(self.num_rules) -# } + :param node_ip_to_id: Mapping between IP address and ID. + :type node_ip_to_id: Dict[str, int] + :param ports: List of ports which are part of the game that define the ordering when converting to an ID + :type ports: List[int] + :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID + :type protocols: list[str] + :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
A typical + example may look like this: + ['network','nodes',,'acl','acl'] + :type where: Optional[Tuple[str]], optional + :param num_rules: , defaults to 10 + :type num_rules: int, optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + self.num_rules: int = num_rules + self.node_to_id: Dict[str, int] = node_ip_to_id + "List of node IP addresses, order in this list determines how they are converted to an ID" + self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} + "List of ports which are part of the game that define the ordering when converting to an ID" + self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} + "List of protocols which are part of the game, defines ordering when converting to an ID" + self.default_observation: Dict = { + i + + 1: { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + for i in range(self.num_rules) + } -# def observe(self, state: Dict) -> Dict: -# """Generate observation based on the current state of the simulation. + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. 
-# :param state: Simulation state dictionary -# :type state: Dict -# :return: Observation -# :rtype: Dict -# """ -# if self.where is None: -# return self.default_observation -# acl_state: Dict = access_from_nested_dict(state, self.where) -# if acl_state is NOT_PRESENT_IN_STATE: -# return self.default_observation + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + acl_state: Dict = access_from_nested_dict(state, self.where) + if acl_state is NOT_PRESENT_IN_STATE: + return self.default_observation -# # TODO: what if the ACL has more rules than num of max rules for obs space -# obs = {} -# acl_items = dict(acl_state.items()) -# i = 1 # don't show rule 0 for compatibility reasons. -# while i < self.num_rules + 1: -# rule_state = acl_items[i] -# if rule_state is None: -# obs[i] = { -# "position": i - 1, -# "permission": 0, -# "source_node_id": 0, -# "source_port": 0, -# "dest_node_id": 0, -# "dest_port": 0, -# "protocol": 0, -# } -# else: -# src_ip = rule_state["src_ip_address"] -# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] -# dst_ip = rule_state["dst_ip_address"] -# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] -# src_port = rule_state["src_port"] -# src_port_id = 1 if src_port is None else self.port_to_id[src_port] -# dst_port = rule_state["dst_port"] -# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] -# protocol = rule_state["protocol"] -# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] -# obs[i] = { -# "position": i - 1, -# "permission": rule_state["action"], -# "source_node_id": src_node_id, -# "source_port": src_port_id, -# "dest_node_id": dst_node_ip, -# "dest_port": dst_port_id, -# "protocol": protocol_id, -# } -# i += 1 -# return obs + # TODO: what if the ACL has more rules than num of max rules for obs space + obs = {} + acl_items = 
dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] + if rule_state is None: + obs[i] = { + "position": i - 1, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + else: + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, + "permission": rule_state["action"], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, + } + i += 1 + return obs -# @property -# def space(self) -> spaces.Space: -# """Gymnasium space object describing the observation space shape. + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. 
-# :return: Gymnasium space -# :rtype: spaces.Space -# """ -# return spaces.Dict( -# { -# i -# + 1: spaces.Dict( -# { -# "position": spaces.Discrete(self.num_rules), -# "permission": spaces.Discrete(3), -# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) -# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), -# "source_port": spaces.Discrete(len(self.port_to_id) + 2), -# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), -# "dest_port": spaces.Discrete(len(self.port_to_id) + 2), -# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), -# } -# ) -# for i in range(self.num_rules) -# } -# ) + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "source_port": spaces.Discrete(len(self.port_to_id) + 2), + "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), + } + ) + for i in range(self.num_rules) + } + ) -# @classmethod -# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": -# """Generate ACL observation from a config. + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": + """Generate ACL observation from a config. -# :param config: Dictionary containing the configuration for this ACL observation. -# :type config: Dict -# :param game: Reference to the PrimaiteGame object that spawned this observation. 
-# :type game: PrimaiteGame -# :return: Observation object -# :rtype: AclObservation -# """ -# max_acl_rules = config["options"]["max_acl_rules"] -# node_ip_to_idx = {} -# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): -# node_ref = ip_map_config["node_hostname"] -# nic_num = ip_map_config["nic_num"] -# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] -# nic_obj = node_obj.network_interface[nic_num] -# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 + :param config: Dictionary containing the configuration for this ACL observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Observation object + :rtype: AclObservation + """ + max_acl_rules = config["options"]["max_acl_rules"] + node_ip_to_idx = {} + for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): + node_ref = ip_map_config["node_hostname"] + nic_num = ip_map_config["nic_num"] + node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] + nic_obj = node_obj.network_interface[nic_num] + node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 -# router_hostname = config["router_hostname"] -# return cls( -# node_ip_to_id=node_ip_to_idx, -# ports=game.options.ports, -# protocols=game.options.protocols, -# where=["network", "nodes", router_hostname, "acl", "acl"], -# num_rules=max_acl_rules, -# ) + router_hostname = config["router_hostname"] + return cls( + node_ip_to_id=node_ip_to_idx, + ports=game.options.ports, + protocols=game.options.protocols, + where=["network", "nodes", router_hostname, "acl", "acl"], + num_rules=max_acl_rules, + ) -# class NullObservation(AbstractObservation): -# """Null observation, returns a single 0 value for the observation space.""" +class NullObservation(AbstractObservation): + """Null observation, returns a single 0 value for the observation space.""" -# def __init__(self, where: Optional[List[str]] = None): -# """Initialise 
null observation.""" -# self.default_observation: Dict = {} + def __init__(self, where: Optional[List[str]] = None): + """Initialise null observation.""" + self.default_observation: Dict = {} -# def observe(self, state: Dict) -> Dict: -# """Generate observation based on the current state of the simulation.""" -# return 0 + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation.""" + return 0 -# @property -# def space(self) -> spaces.Space: -# """Gymnasium space object describing the observation space shape.""" -# return spaces.Discrete(1) + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Discrete(1) -# @classmethod -# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": -# """ -# Create null observation from a config. + @classmethod + def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": + """ + Create null observation from a config. -# The parameters are ignored, they are here to match the signature of the other observation classes. -# """ -# return cls() + The parameters are ignored, they are here to match the signature of the other observation classes. 
+ """ + return cls() -# class ICSObservation(NullObservation): -# """ICS observation placeholder, currently not implemented so always returns a single 0.""" +class ICSObservation(NullObservation): + """ICS observation placeholder, currently not implemented so always returns a single 0.""" -# pass + pass +''' diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py new file mode 100644 index 00000000..b8dee2c2 --- /dev/null +++ b/src/primaite/game/agent/observations/router_observation.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import Dict, List, Optional + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.nic_observations import PortObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + + +class RouterObservation(AbstractObservation, identifier="ROUTER"): + """Router observation, provides status information about a router within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for RouterObservation.""" + + hostname: str + """Hostname of the router, used for querying simulation state dictionary.""" + ports: Optional[List[PortObservation.ConfigSchema]] = None + """Configuration of port observations for this router.""" + num_ports: Optional[int] = None + """Number of port observations configured for this router.""" + acl: Optional[ACLObservation.ConfigSchema] = None + """Configuration of ACL observation on this router.""" + ip_list: Optional[List[str]] = None + """List of IP addresses for encoding ACLs.""" + wildcard_list: Optional[List[str]] = None + """List of IP wildcards 
for encoding ACLs.""" + port_list: Optional[List[int]] = None + """List of ports for encoding ACLs.""" + protocol_list: Optional[List[str]] = None + """List of protocols for encoding ACLs.""" + num_rules: Optional[int] = None + """Number of rules ACL rules to show.""" + + def __init__( + self, + where: WhereType, + ports: List[PortObservation], + num_ports: int, + acl: ACLObservation, + ) -> None: + """ + Initialize a router observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this router. + A typical location for a router might be ['network', 'nodes', ]. + :type where: WhereType + :param ports: List of port observations representing the ports of the router. + :type ports: List[PortObservation] + :param num_ports: Number of ports for the router. + :type num_ports: int + :param acl: ACL observation representing the access control list of the router. + :type acl: ACLObservation + """ + self.where: WhereType = where + self.ports: List[PortObservation] = ports + self.acl: ACLObservation = acl + self.num_ports: int = num_ports + + while len(self.ports) < num_ports: + self.ports.append(PortObservation(where=None)) + while len(self.ports) > num_ports: + self.ports.pop() + msg = "Too many ports in router observation. Truncating." + _LOGGER.warning(msg) + + self.default_observation = { + "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, + "ACL": self.acl.default_observation, + } + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status of ports and ACL configuration of the router. 
+ :rtype: ObsType + """ + router_state = access_from_nested_dict(state, self.where) + if router_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + obs["ACL"] = self.acl.observe(state) + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for router status. + :rtype: spaces.Space + """ + return spaces.Dict( + {"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space} + ) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation: + """ + Create a router observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the router observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this router's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed router observation instance. 
+ :rtype: RouterObservation + """ + where = parent_where + ["nodes", config.hostname] + + if config.acl is None: + config.acl = ACLObservation.ConfigSchema() + if config.acl.num_rules is None: + config.acl.num_rules = config.num_rules + if config.acl.ip_list is None: + config.acl.ip_list = config.ip_list + if config.acl.wildcard_list is None: + config.acl.wildcard_list = config.wildcard_list + if config.acl.port_list is None: + config.acl.port_list = config.port_list + if config.acl.protocol_list is None: + config.acl.protocol_list = config.protocol_list + + if config.ports is None: + config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)] + + ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] + acl = ACLObservation.from_config(config=config.acl, parent_where=where) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 6caf791c..eb94651d 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,45 +1,43 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from __future__ import annotations + +from typing import Dict from gymnasium import spaces +from gymnasium.core import ObsType -from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +class ServiceObservation(AbstractObservation, identifier="SERVICE"): + """Service observation, shows status of a service in the simulation environment.""" -class ServiceObservation(AbstractObservation): - """Observation of a service in 
the network.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ServiceObservation.""" - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} - "Default observation is what should be returned when the service doesn't exist." + service_name: str + """Name of the service, used for querying simulation state dictionary""" - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise service observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'services', ]` - :type where: Optional[List[str]] + def __init__(self, where: WhereType) -> None: """ - super().__init__() - self.where: Optional[Tuple[str]] = where + Initialize a service observation instance. - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. + :param where: Where in the simulation state dictionary to find the relevant information for this service. + A typical location for a service might be ['network', 'nodes', , 'services', ]. + :type where: WhereType + """ + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0} - :param state: Simulation state dictionary + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. :type state: Dict - :return: Observation - :rtype: Dict + :return: Observation containing the operating status and health status of the service. 
+ :rtype: ObsType """ - if self.where is None: - return self.default_observation - service_state = access_from_nested_dict(state, self.where) if service_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -50,114 +48,96 @@ class ServiceObservation(AbstractObservation): @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for service status. + :rtype: spaces.Space + """ return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ServiceObservation": - """Create service observation from a config. + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: + """ + Create a service observation from a configuration schema. - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation + :param config: Configuration schema containing the necessary information for the service observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this service's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed service observation instance. 
:rtype: ServiceObservation """ - return cls(where=parent_where + ["services", config["service_name"]]) + return cls(where=parent_where + ["services", config.service_name]) -class ApplicationObservation(AbstractObservation): - """Observation of an application in the network.""" +class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): + """Application observation, shows the status of an application within the simulation environment.""" - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0} - "Default observation is what should be returned when the application doesn't exist." + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ApplicationObservation.""" - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise application observation. + application_name: str + """Name of the application, used for querying simulation state dictionary""" - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'applications', ]` - :type where: Optional[List[str]] + def __init__(self, where: WhereType) -> None: """ - super().__init__() - self.where: Optional[Tuple[str]] = where + Initialise an application observation instance. - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. + :param where: Where in the simulation state dictionary to find the relevant information for this application. + A typical location for an application might be + ['network', 'nodes', , 'applications', ]. 
+ :type where: WhereType + """ + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - :param state: Simulation state dictionary + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. :type state: Dict - :return: Observation - :rtype: Dict + :return: Obs containing the operating status, health status, and number of executions of the application. + :rtype: ObsType """ - if self.where is None: - return self.default_observation - - app_state = access_from_nested_dict(state, self.where) - if app_state is NOT_PRESENT_IN_STATE: + application_state = access_from_nested_dict(state, self.where) + if application_state is NOT_PRESENT_IN_STATE: return self.default_observation return { - "operating_status": app_state["operating_state"], - "health_status": app_state["health_state_visible"], - "num_executions": self._categorise_num_executions(app_state["num_executions"]), + "operating_status": application_state["operating_state"], + "health_status": application_state["health_state_visible"], + "num_executions": application_state["num_executions"], } @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for application status. + :rtype: spaces.Space + """ return spaces.Dict( { "operating_status": spaces.Discrete(7), - "health_status": spaces.Discrete(6), + "health_status": spaces.Discrete(5), "num_executions": spaces.Discrete(4), } ) @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ApplicationObservation": - """Create application observation from a config. 
+ def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation: + """ + Create an application observation from a configuration schema. - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation + :param config: Configuration schema containing the necessary information for the application observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this application's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["services", config["application_name"]]) - - @classmethod - def _categorise_num_executions(cls, num_executions: int) -> int: - """ - Categorise the number of executions of an application. - - Helps classify the number of application executions into different categories. 
- - Current categories: - - 0: Application is never executed - - 1: Application is executed a low number of times (1-5) - - 2: Application is executed often (6-10) - - 3: Application is executed a high number of times (more than 10) - - :param: num_executions: Number of times the application is executed - """ - if num_executions > 10: - return 3 - elif num_executions > 5: - return 2 - elif num_executions > 0: - return 1 - return 0 + return cls(where=parent_where + ["applications", config.application_name]) From 15cb2e6970a184c83c6d56c01ad3ae3f26660b1e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 31 Mar 2024 17:31:10 +0100 Subject: [PATCH 09/16] #2417 Add NestedObservation --- .../agent/observations/acl_observation.py | 2 +- .../observations/file_system_observations.py | 4 +- .../observations/firewall_observation.py | 2 +- .../agent/observations/host_observations.py | 2 +- .../agent/observations/nic_observations.py | 4 +- .../agent/observations/node_observations.py | 2 +- .../agent/observations/observation_manager.py | 136 ++++++++++++++++-- .../game/agent/observations/observations.py | 11 +- .../agent/observations/router_observation.py | 2 +- .../observations/software_observation.py | 2 +- 10 files changed, 140 insertions(+), 27 deletions(-) diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 2d29223d..7601e678 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -40,7 +40,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): protocol_list: List[str], ) -> None: """ - Initialize an ACL observation instance. + Initialise an ACL observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
:type where: WhereType diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index a30bfc82..3c931bc8 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -25,7 +25,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): def __init__(self, where: WhereType, include_num_access: bool) -> None: """ - Initialize a file observation instance. + Initialise a file observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this file. A typical location for a file might be @@ -107,7 +107,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool ) -> None: """ - Initialize a folder observation instance. + Initialise a folder observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this folder. A typical location for a folder might be ['network', 'nodes', , 'folders', ]. diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 6397d473..376e4824 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -42,7 +42,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): num_rules: int, ) -> None: """ - Initialize a firewall observation instance. + Initialise a firewall observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this firewall. A typical location for a firewall might be ['network', 'nodes', ]. 
diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 34c9b3ff..9146979a 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -62,7 +62,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_num_access: bool, ) -> None: """ - Initialize a host observation instance. + Initialise a host observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this host. A typical location for a host might be ['network', 'nodes', ]. diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 3be53112..ff2731ff 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -22,7 +22,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): def __init__(self, where: WhereType, include_nmne: bool) -> None: """ - Initialize a network interface observation instance. + Initialise a network interface observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this interface. A typical location for a network interface might be @@ -108,7 +108,7 @@ class PortObservation(AbstractObservation, identifier="PORT"): def __init__(self, where: WhereType) -> None: """ - Initialize a port observation instance. + Initialise a port observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this port. A typical location for a port might be ['network', 'nodes', , 'NICs', ]. 
diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 0e63f440..3f384ece 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -61,7 +61,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewalls: List[FirewallObservation], ) -> None: """ - Initialize a nodes observation instance. + Initialise a nodes observation instance. :param where: Where in the simulation state dictionary to find the relevant information for nodes. A typical location for nodes might be ['network', 'nodes']. diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index be90041e..a6981ddc 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -1,6 +1,10 @@ -from typing import Dict, TYPE_CHECKING +from __future__ import annotations +from typing import Any, Dict, List, TYPE_CHECKING + +from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import BaseModel, ConfigDict, model_validator, ValidationError from primaite.game.agent.observations.observations import AbstractObservation @@ -8,6 +12,114 @@ if TYPE_CHECKING: from primaite.game.game import PrimaiteGame +class NestedObservation(AbstractObservation, identifier="CUSTOM"): + """Observation type that allows combining other observations into a gymnasium.spaces.Dict space.""" + + class NestedObservationItem(BaseModel): + """One list item of the config.""" + + model_config = ConfigDict(extra="forbid") + type: str + """Select observation class. 
It maps to the identifier of the obs class by checking the registry.""" + label: str + """Dict key in the final observation space.""" + options: Dict + """Options to pass to the observation class from_config method.""" + + @model_validator(mode="after") + def check_model(self) -> "NestedObservation.NestedObservationItem": + """Make sure tha the config options match up with the selected observation type.""" + obs_subclass_name = self.type + obs_options = self.options + if obs_subclass_name not in AbstractObservation._registry: + raise ValueError(f"Observation of type {obs_subclass_name} could not be found.") + obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema + try: + obs_schema(**obs_options) + except ValidationError as e: + raise ValueError(f"Observation options did not match schema, got this error: {e}") + return self + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for NestedObservation.""" + + components: List[NestedObservation.NestedObservationItem] + """List of observation components to be part of this space.""" + + def __init__(self, components: Dict[str, AbstractObservation]) -> None: + """Initialise nested observation.""" + self.components: Dict[str, AbstractObservation] = components + """Maps label: observation object""" + + self.default_observation = {label: obs.default_observation for label, obs in self.components.items()} + """Default observation is just the default observations of constituents.""" + + def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status information about the host. + :rtype: ObsType + """ + return {label: obs.observe(state) for label, obs in self.components.items()} + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. 
+ + :return: Gymnasium space representing the nested observation space. + :rtype: spaces.Space + """ + return spaces.Dict({label: obs.space for label, obs in self.components.items()}) + + @classmethod + def from_config(cls, config: ConfigSchema) -> NestedObservation: + """ + Read the Nested observation config and create all defined subcomponents. + + Example configuration that utilises NestedObservation: + This lets us have different options for different types of hosts. + + ```yaml + observation_space: + - type: CUSTOM + options: + components: + + - type: HOSTS + label: COMPUTERS # What is the dictionary key called + options: + hosts: + - client_1 + - client_2 + num_services: 0 + num_applications: 5 + ... # other options + + - type: HOSTS + label: SERVERS # What is the dictionary key called + options: + hosts: + - hostname: database_server + - hostname: web_server + num_services: 4 + num_applications: 0 + num_folders: 2 + num_files: 2 + + ``` + """ + instances = dict() + for component in config.components: + obs_class = AbstractObservation._registry[component.type] + obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options)) + instances[component.label] = obs_instance + return cls(components=instances) + + class ObservationManager: """ Manage the observations of an Agent. @@ -18,18 +130,15 @@ class ObservationManager: 3. Formatting this information so an agent can use it to make decisions. """ - # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed - # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next - # refactor. - - def __init__(self, observation: AbstractObservation) -> None: + def __init__(self, obs: AbstractObservation) -> None: """Initialise observation space. 
:param observation: Observation object :type observation: AbstractObservation """ - self.obs: AbstractObservation = observation + self.obs: AbstractObservation = obs self.current_observation: ObsType + """Cached copy of the observation at the time it was most recently calculated.""" def update(self, state: Dict) -> Dict: """ @@ -48,7 +157,8 @@ class ObservationManager: @classmethod def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": - """Create observation space from a config. + """ + Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. It should contain the key 'type' which selects which observation class to use (from a choice of: @@ -58,10 +168,8 @@ class ObservationManager: :param game: Reference to the PrimaiteGame object that spawned this observation. :type game: PrimaiteGame """ - - for obs_cfg in config: - obs_type = obs_cfg['type'] - obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options'])) + obs_type = config["type"] + obs_class = AbstractObservation._registry[obs_type] + observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"])) obs_manager = cls(observation) - return obs_manager \ No newline at end of file + return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 08871072..feddc3ed 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, Type from gymnasium import spaces +from gymnasium.core import ObsType from pydantic import BaseModel, ConfigDict from primaite import getLogger @@ -26,6 +27,10 @@ class AbstractObservation(ABC): Automatically populated when subclasses are defined. Used for defining from_config. 
""" + def __init__(self) -> None: + """Initialise an observation. This method must be overwritten.""" + self.default_observation: ObsType + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: """ Register an observation type. @@ -58,10 +63,10 @@ class AbstractObservation(ABC): pass @classmethod - def from_config(cls, cfg: Dict) -> "AbstractObservation": + @abstractmethod + def from_config(cls, config: ConfigSchema) -> "AbstractObservation": """Create this observation space component form a serialised format.""" - ObservationType = cls._registry[cfg["type"]] - return ObservationType.from_config(cfg=cfg) + return cls() ''' diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index b8dee2c2..97d8ab41 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -47,7 +47,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): acl: ACLObservation, ) -> None: """ - Initialize a router observation instance. + Initialise a router observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this router. A typical location for a router might be ['network', 'nodes', ]. diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index eb94651d..0c031345 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -20,7 +20,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): def __init__(self, where: WhereType) -> None: """ - Initialize a service observation instance. + Initialise a service observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this service. 
A typical location for a service might be ['network', 'nodes', , 'services', ]. From 62ebca8c08e9966cb52d29d01e5a98b7cbb9aff8 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 31 Mar 2024 21:39:24 +0100 Subject: [PATCH 10/16] #2417 Remove references to old obs names and add link obs --- CHANGELOG.md | 2 +- .../network/network_interfaces.rst | 2 +- .../agent/observations/acl_observation.py | 6 +- .../observations/file_system_observations.py | 10 +- .../observations/firewall_observation.py | 8 +- .../agent/observations/host_observations.py | 18 +- .../agent/observations/link_observation.py | 155 ++++++++++ .../agent/observations/nic_observations.py | 9 +- .../agent/observations/node_observations.py | 12 +- .../agent/observations/observation_manager.py | 14 +- .../game/agent/observations/observations.py | 275 +----------------- .../agent/observations/router_observation.py | 10 +- .../observations/software_observation.py | 13 +- src/primaite/simulator/network/nmne.py | 2 +- tests/conftest.py | 5 +- .../observations/test_acl_observations.py | 4 +- .../observations/test_link_observations.py | 2 +- .../observations/test_nic_observations.py | 12 +- .../observations/test_node_observations.py | 4 +- .../network/test_capture_nmne.py | 10 +- .../_game/_agent/test_probabilistic_agent.py | 5 +- 21 files changed, 247 insertions(+), 331 deletions(-) create mode 100644 src/primaite/game/agent/observations/link_observation.py diff --git a/CHANGELOG.md b/CHANGELOG.md index c01f0139..8931a3d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -119,7 +119,7 @@ SessionManager. - Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios. - **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. 
This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules. - Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them. -- Integration of NMNE capturing functionality within the `NicObservation` class. +- Integration of NMNE capturing functionality within the `NICObservation` class. - Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario ### Removed diff --git a/docs/source/simulation_components/network/network_interfaces.rst b/docs/source/simulation_components/network/network_interfaces.rst index ffba58e4..f50a1baa 100644 --- a/docs/source/simulation_components/network/network_interfaces.rst +++ b/docs/source/simulation_components/network/network_interfaces.rst @@ -73,7 +73,7 @@ Network Interface Classes - Malicious Network Events Monitoring: * Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns. - * Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies. + * Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NICObservation`` to classify and record network anomalies. * Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms. 
**WiredNetworkInterface (Connection Type Layer)** diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 7601e678..ac599ea0 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -1,7 +1,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -10,6 +10,8 @@ from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -165,7 +167,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation: """ Create an ACL observation from a configuration schema. 
diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 3c931bc8..a7c56a89 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -9,6 +9,8 @@ from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -73,7 +75,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): return spaces.Dict(space) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation: """ Create a file observation from a configuration schema. @@ -172,7 +174,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation: """ Create a folder observation from a configuration schema. 
@@ -190,5 +192,5 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): for file_config in config.files: file_config.include_num_access = config.include_num_access - files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files] return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 376e4824..69398d96 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -10,6 +10,8 @@ from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -190,7 +192,9 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: + def from_config( + cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] + ) -> FirewallObservation: """ Create a firewall observation from a configuration schema. 
diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 9146979a..d71583b3 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -12,6 +12,8 @@ from primaite.game.agent.observations.observations import AbstractObservation, W from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -184,7 +186,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): return spaces.Dict(shape) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation: """ Create a host observation from a configuration schema. @@ -196,7 +198,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): :return: Constructed host observation instance. 
:rtype: HostObservation """ - if parent_where is None: + if parent_where == []: where = ["network", "nodes", config.hostname] else: where = parent_where + ["nodes", config.hostname] @@ -208,10 +210,12 @@ class HostObservation(AbstractObservation, identifier="HOST"): for nic_config in config.network_interfaces: nic_config.include_nmne = config.include_nmne - services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] - applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] - folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] - nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] + services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services] + applications = [ + ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications + ] + folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders] + nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces] return cls( where=where, diff --git a/src/primaite/game/agent/observations/link_observation.py b/src/primaite/game/agent/observations/link_observation.py new file mode 100644 index 00000000..f810bb36 --- /dev/null +++ b/src/primaite/game/agent/observations/link_observation.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from typing import Any, Dict, List, TYPE_CHECKING + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame +_LOGGER = getLogger(__name__) + 
+ +class LinkObservation(AbstractObservation, identifier="LINK"): + """Link observation, providing information about a specific link within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for LinkObservation.""" + + link_reference: str + """Reference identifier for the link.""" + + def __init__(self, where: WhereType) -> None: + """ + Initialise a link observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this link. + A typical location for a link might be ['network', 'links', ]. + :type where: WhereType + """ + self.where = where + self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}} + + def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing information about the link. + :rtype: Any + """ + link_state = access_from_nested_dict(state, self.where) + if link_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + bandwidth = link_state["bandwidth"] + load = link_state["current_load"] + if load == 0: + utilisation_category = 0 + else: + utilisation_fraction = load / bandwidth + utilisation_category = int(utilisation_fraction * 9) + 1 + + return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for link status. + :rtype: spaces.Space + """ + return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) + + @classmethod + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation: + """ + Create a link observation from a configuration schema. 
+ + :param config: Configuration schema containing the necessary information for the link observation. + :type config: ConfigSchema + :param game: The PrimaiteGame instance. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this link. + A typical location might be ['network', 'links', ]. + :type parent_where: WhereType, optional + :return: Constructed link observation instance. + :rtype: LinkObservation + """ + link_reference = game.ref_map_links[config.link_reference] + if parent_where == []: + where = ["network", "links", link_reference] + else: + where = parent_where + ["links", link_reference] + return cls(where=where) + + +class LinksObservation(AbstractObservation, identifier="LINKS"): + """Collection of link observations representing multiple links within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for LinksObservation.""" + + link_references: List[str] + """List of reference identifiers for the links.""" + + def __init__(self, where: WhereType, links: List[LinkObservation]) -> None: + """ + Initialise a links observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for these links. + A typical location for links might be ['network', 'links']. + :type where: WhereType + :param links: List of link observations. + :type links: List[LinkObservation] + """ + self.where: WhereType = where + self.links: List[LinkObservation] = links + self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)} + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing information about multiple links. 
+ :rtype: ObsType + """ + return {i + 1: l.observe(state) for i, l in enumerate(self.links)} + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for multiple links. + :rtype: spaces.Space + """ + return {i + 1: l.space for i, l in enumerate(self.links)} + + @classmethod + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation: + """ + Create a links observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the links observation. + :type config: ConfigSchema + :param game: The PrimaiteGame instance. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about these links. + A typical location might be ['network']. + :type parent_where: WhereType, optional + :return: Constructed links observation instance. 
+ :rtype: LinksObservation + """ + where = parent_where + ["network"] + link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references] + links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs] + return cls(where=where, links=links) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index ff2731ff..19826f84 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, Optional +from typing import Dict, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -8,6 +8,9 @@ from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): """Status information about a network interface within the simulation environment.""" @@ -82,7 +85,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation: """ Create a network interface observation from a configuration schema. 
@@ -142,7 +145,7 @@ class PortObservation(AbstractObservation, identifier="PORT"): return spaces.Dict({"operating_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation: """ Create a port observation from a configuration schema. diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 3f384ece..7d227bb7 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List +from typing import Dict, List, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -11,6 +11,8 @@ from primaite.game.agent.observations.host_observations import HostObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.observations.router_observation import RouterObservation +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -119,7 +121,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation: """ Create a nodes observation from a configuration schema. 
@@ -178,8 +180,8 @@ class NodesObservation(AbstractObservation, identifier="NODES"): if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules - hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] - routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] - firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls] + hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts] + routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers] + firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls] return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls) diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index a6981ddc..84311984 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -1,12 +1,12 @@ from __future__ import annotations -from typing import Any, Dict, List, TYPE_CHECKING +from typing import Dict, List, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType from pydantic import BaseModel, ConfigDict, model_validator, ValidationError -from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType if TYPE_CHECKING: from primaite.game.game import PrimaiteGame @@ -43,7 +43,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for NestedObservation.""" - components: List[NestedObservation.NestedObservationItem] + components: List[NestedObservation.NestedObservationItem] = [] 
"""List of observation components to be part of this space.""" def __init__(self, components: Dict[str, AbstractObservation]) -> None: @@ -54,7 +54,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): self.default_observation = {label: obs.default_observation for label, obs in self.components.items()} """Default observation is just the default observations of constituents.""" - def observe(self, state: Dict) -> Any: + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. @@ -76,7 +76,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): return spaces.Dict({label: obs.space for label, obs in self.components.items()}) @classmethod - def from_config(cls, config: ConfigSchema) -> NestedObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation: """ Read the Nested observation config and create all defined subcomponents. @@ -115,7 +115,7 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): instances = dict() for component in config.components: obs_class = AbstractObservation._registry[component.type] - obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options)) + obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game) instances[component.label] = obs_instance return cls(components=instances) @@ -170,6 +170,6 @@ class ObservationManager: """ obs_type = config["type"] obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"])) + observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game) obs_manager = cls(observation) return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index feddc3ed..6c9db571 100644 --- 
a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,6 +1,6 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from typing import Any, Dict, Iterable, Type +from typing import Any, Dict, Iterable, Type, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -8,8 +8,9 @@ from pydantic import BaseModel, ConfigDict from primaite import getLogger +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) - WhereType = Iterable[str | int] | None @@ -64,272 +65,8 @@ class AbstractObservation(ABC): @classmethod @abstractmethod - def from_config(cls, config: ConfigSchema) -> "AbstractObservation": + def from_config( + cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] + ) -> "AbstractObservation": """Create this observation space component form a serialised format.""" return cls() - - -''' -class LinkObservation(AbstractObservation): - """Observation of a link in the network.""" - - default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} - "Default observation is what should be returned when the link doesn't exist." - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise link observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'servics', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - link_state = access_from_nested_dict(state, self.where) - if link_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - bandwidth = link_state["bandwidth"] - load = link_state["current_load"] - if load == 0: - utilisation_category = 0 - else: - utilisation_fraction = load / bandwidth - # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% - utilisation_category = int(utilisation_fraction * 9) + 1 - - # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": - """Create link observation from a config. - - :param config: Dictionary containing the configuration for this link observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed link observation - :rtype: LinkObservation - """ - return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) - - -class AclObservation(AbstractObservation): - """Observation of an Access Control List (ACL) in the network.""" - - # TODO: should where be optional, and we can use where=None to pad the observation space? - # definitely the current approach does not support tracking files that aren't specified by name, for example - # if a file is created at runtime, we have currently got no way of telling the observation space to track it. 
- # this needs adding, but not for the MVP. - def __init__( - self, - node_ip_to_id: Dict[str, int], - ports: List[int], - protocols: List[str], - where: Optional[Tuple[str]] = None, - num_rules: int = 10, - ) -> None: - """Initialise ACL observation. - - :param node_ip_to_id: Mapping between IP address and ID. - :type node_ip_to_id: Dict[str, int] - :param ports: List of ports which are part of the game that define the ordering when converting to an ID - :type ports: List[int] - :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID - :type protocols: list[str] - :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical - example may look like this: - ['network','nodes',,'acl','acl'] - :type where: Optional[Tuple[str]], optional - :param num_rules: , defaults to 10 - :type num_rules: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.num_rules: int = num_rules - self.node_to_id: Dict[str, int] = node_ip_to_id - "List of node IP addresses, order in this list determines how they are converted to an ID" - self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} - "List of ports which are part of the game that define the ordering when converting to an ID" - self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} - "List of protocols which are part of the game, defines ordering when converting to an ID" - self.default_observation: Dict = { - i - + 1: { - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - for i in range(self.num_rules) - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. 
- - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - acl_state: Dict = access_from_nested_dict(state, self.where) - if acl_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - # TODO: what if the ACL has more rules than num of max rules for obs space - obs = {} - acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. - while i < self.num_rules + 1: - rule_state = acl_items[i] - if rule_state is None: - obs[i] = { - "position": i - 1, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - else: - src_ip = rule_state["src_ip_address"] - src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] - dst_ip = rule_state["dst_ip_address"] - dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] - src_port = rule_state["src_port"] - src_port_id = 1 if src_port is None else self.port_to_id[src_port] - dst_port = rule_state["dst_port"] - dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] - protocol = rule_state["protocol"] - protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] - obs[i] = { - "position": i - 1, - "permission": rule_state["action"], - "source_node_id": src_node_id, - "source_port": src_port_id, - "dest_node_id": dst_node_ip, - "dest_port": dst_port_id, - "protocol": protocol_id, - } - i += 1 - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. 
- - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": - """Generate ACL observation from a config. - - :param config: Dictionary containing the configuration for this ACL observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Observation object - :rtype: AclObservation - """ - max_acl_rules = config["options"]["max_acl_rules"] - node_ip_to_idx = {} - for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): - node_ref = ip_map_config["node_hostname"] - nic_num = ip_map_config["nic_num"] - node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] - nic_obj = node_obj.network_interface[nic_num] - node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - - router_hostname = config["router_hostname"] - return cls( - node_ip_to_id=node_ip_to_idx, - ports=game.options.ports, - protocols=game.options.protocols, - where=["network", "nodes", router_hostname, "acl", "acl"], - num_rules=max_acl_rules, - ) - - -class NullObservation(AbstractObservation): - """Null observation, returns a single 0 value for the observation space.""" - - def __init__(self, where: Optional[List[str]] = None): - """Initialise null observation.""" - self.default_observation: Dict = {} - - def 
observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - return 0 - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Discrete(1) - - @classmethod - def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": - """ - Create null observation from a config. - - The parameters are ignored, they are here to match the signature of the other observation classes. - """ - return cls() - - -class ICSObservation(NullObservation): - """ICS observation placeholder, currently not implemented so always returns a single 0.""" - - pass -''' diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index 97d8ab41..c2919770 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, Optional +from typing import Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -11,6 +11,8 @@ from primaite.game.agent.observations.nic_observations import PortObservation from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame _LOGGER = getLogger(__name__) @@ -107,7 +109,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation: + def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation: """ Create a router observation from a configuration schema. 
@@ -137,6 +139,6 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): if config.ports is None: config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)] - ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] - acl = ACLObservation.from_config(config=config.acl, parent_where=where) + ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports] + acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where) return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 0c031345..40788760 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict +from typing import Dict, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -8,6 +8,9 @@ from gymnasium.core import ObsType from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + class ServiceObservation(AbstractObservation, identifier="SERVICE"): """Service observation, shows status of a service in the simulation environment.""" @@ -57,7 +60,9 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: + def from_config( + cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] + ) -> ServiceObservation: """ Create a 
service observation from a configuration schema. @@ -128,7 +133,9 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation: + def from_config( + cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] + ) -> ApplicationObservation: """ Create an application observation from a configuration schema. diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py index 87839712..1b3d838d 100644 --- a/src/primaite/simulator/network/nmne.py +++ b/src/primaite/simulator/network/nmne.py @@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True NMNE_CAPTURE_KEYWORDS: List[str] = [] """List of keywords to identify malicious network events.""" -# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically +# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically CAPTURE_BY_DIRECTION: Final[bool] = True """Flag to determine if captures should be organized by traffic direction (inbound/outbound).""" CAPTURE_BY_IP_ADDRESS: Final[bool] = False diff --git a/tests/conftest.py b/tests/conftest.py index 078a78bd..b08fd838 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,8 +10,7 @@ from _pytest.monkeypatch import MonkeyPatch from primaite import getLogger, PRIMAITE_PATHS from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent -from primaite.game.agent.observations.observation_manager import ObservationManager -from primaite.game.agent.observations.observations import ICSObservation +from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.game import PrimaiteGame from primaite.session.session import 
PrimaiteSession @@ -525,7 +524,7 @@ def game_and_agent(): ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"], act_map={}, ) - observation_space = ObservationManager(ICSObservation()) + observation_space = ObservationManager(NestedObservation(components={})) reward_function = RewardFunction() test_agent = ControlledAgent( diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index 93867edd..d0710f5f 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -1,6 +1,6 @@ import pytest -from primaite.game.agent.observations.observations import AclObservation +from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.transport_layer import Port @@ -34,7 +34,7 @@ def test_acl_observations(simulation): # add router acl rule router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) - acl_obs = AclObservation( + acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], node_ip_to_id={}, ports=["NTP", "HTTP", "POSTGRES_SERVER"], diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index bfe4d5cc..b13314f1 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -1,7 +1,7 @@ import pytest from gymnasium import spaces -from primaite.game.agent.observations.observations import LinkObservation +from 
primaite.game.agent.observations.link_observation import LinkObservation from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.base import Link, Node from primaite.simulator.network.hardware.nodes.host.computer import Computer diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index 332bc1f7..bc4261ce 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -5,7 +5,7 @@ import pytest import yaml from gymnasium import spaces -from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.agent.observations.nic_observations import NICObservation from primaite.game.game import PrimaiteGame from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC @@ -40,7 +40,7 @@ def test_nic(simulation): nic: NIC = pc.network_interface[1] - nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) assert nic_obs.space["nic_status"] == spaces.Discrete(3) assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) @@ -61,13 +61,13 @@ def test_nic_categories(simulation): """Test the NIC observation nmne count categories.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) assert nic_obs.high_nmne_threshold == 10 # default assert nic_obs.med_nmne_threshold == 5 # default assert nic_obs.low_nmne_threshold == 0 # default - nic_obs = NicObservation( + nic_obs = NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], 
low_nmne_threshold=3, med_nmne_threshold=6, @@ -80,7 +80,7 @@ def test_nic_categories(simulation): with pytest.raises(Exception): # should throw an error - NicObservation( + NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], low_nmne_threshold=9, med_nmne_threshold=6, @@ -89,7 +89,7 @@ def test_nic_categories(simulation): with pytest.raises(Exception): # should throw an error - NicObservation( + NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], low_nmne_threshold=3, med_nmne_threshold=9, diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index dce05b6a..2926ffa6 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -4,7 +4,7 @@ from uuid import uuid4 import pytest from gymnasium import spaces -from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.game.agent.observations.host_observations import HostObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation @@ -23,7 +23,7 @@ def test_node_observation(simulation): """Test a Node observation.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - node_obs = NodeObservation(where=["network", "nodes", pc.hostname]) + node_obs = HostObservation(where=["network", "nodes", pc.hostname]) assert node_obs.space["operating_status"] == spaces.Discrete(5) diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 9efc70f7..1578305b 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -1,4 +1,4 @@ -from primaite.game.agent.observations.nic_observations import NicObservation +from 
primaite.game.agent.observations.nic_observations import NICObservation from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.sim_container import Simulation @@ -141,9 +141,9 @@ def test_describe_state_nmne(uc2_network): def test_capture_nmne_observations(uc2_network): """ - Tests the NicObservation class's functionality within a simulated network environment. + Tests the NICObservation class's functionality within a simulated network environment. - This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the + This test ensures the observation space, as defined by instances of NICObservation, accurately reflects the number of MNEs detected based on network activities over multiple iterations. The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update @@ -168,8 +168,8 @@ def test_capture_nmne_observations(uc2_network): set_nmne_config(nmne_config) # Define observations for the NICs of the database and web servers - db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1]) - web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1]) + db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1]) + web_server_nic_obs = NICObservation(where=["network", "nodes", "web_server", "NICs", 1]) # Iterate through a set of test cases to simulate multiple DELETE queries for i in range(0, 20): diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py index c556cfad..7eacb30d 100644 --- a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -1,6 +1,5 @@ from primaite.game.agent.actions import ActionManager -from 
primaite.game.agent.observations.observation_manager import ObservationManager -from primaite.game.agent.observations.observations import ICSObservation +from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager from primaite.game.agent.rewards import RewardFunction from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent @@ -52,7 +51,7 @@ def test_probabilistic_agent(): 2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, }, ) - observation_space = ObservationManager(ICSObservation()) + observation_space = ObservationManager(NestedObservation(components={})) reward_function = RewardFunction() pa = ProbabilisticAgent( From 8da53db82224c0d9543e74ca2825030125a411b9 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 31 Mar 2024 23:20:48 +0100 Subject: [PATCH 11/16] #2417 Finalise parsing of observation space --- .../_package_data/data_manipulation.yaml | 172 +++++++----------- .../game/agent/observations/__init__.py | 12 ++ .../agent/observations/host_observations.py | 17 +- .../agent/observations/link_observation.py | 2 +- .../agent/observations/node_observations.py | 59 ++++-- .../agent/observations/observation_manager.py | 31 +++- 6 files changed, 167 insertions(+), 126 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 06028ee1..d810e58a 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -41,8 +41,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -91,8 +90,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -141,10 +139,7 @@ agents: team: RED type: 
RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -177,102 +172,73 @@ agents: type: ProxyAgent observation_space: - - type: NODES - label: NODES # What is the dictionary key called - options: - hosts: - - hostname: domain_controller - - hostname: web_server - - hostname: database_server - - hostname: backup_server - - hostname: security_suite - - hostname: client_1 - - hostname: client_2 - routers: - - hostname: router_1 - firewalls: {} - - num_host_services: 1 - num_host_applications: 0 - num_host_folders: 1 - num_host_files: 1 - num_host_network_interfaces: 2 - num_router_ports: 4 - num_acl_rules: 10 - num_firewall_ports: 4 - firewalls_internal_inbound_acl: true - firewalls_internal_outbound_acl: true - firewalls_dmz_inbound_acl: true - firewalls_dmz_outbound_acl: true - firewalls_external_inbound_acl: true - firewalls_external_outbound_acl: true - - type: LINKS - label: "LINKS" - options: - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - - observation_space: - type: UC2BlueObservation + type: CUSTOM options: - nodes: - - node_hostname: domain_controller - services: - - service_name: DNSServer - - node_hostname: web_server - services: - - service_name: WebServer - - node_hostname: database_server - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: 
switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git 
a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py index e69de29b..b9d97ae6 100644 --- a/src/primaite/game/agent/observations/__init__.py +++ b/src/primaite/game/agent/observations/__init__.py @@ -0,0 +1,12 @@ +# flake8: noqa +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.game.agent.observations.firewall_observation import FirewallObservation +from primaite.game.agent.observations.host_observations import HostObservation +from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation +from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation +from primaite.game.agent.observations.node_observations import NodesObservation +from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index d71583b3..3ee5f2c7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -94,6 +94,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ self.where: WhereType = where + self.include_num_access = include_num_access + # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services while len(self.services) < num_services: @@ -135,9 +137,10 @@ class HostObservation(AbstractObservation, 
identifier="HOST"): "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, "operating_status": 0, - "num_file_creations": 0, - "num_file_deletions": 0, } + if self.include_num_access: + self.default_observation["num_file_creations"] = 0 + self.default_observation["num_file_deletions"] = 0 def observe(self, state: Dict) -> ObsType: """ @@ -160,8 +163,9 @@ class HostObservation(AbstractObservation, identifier="HOST"): obs["NICS"] = { i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) } - obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] - obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_num_access: + obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] + obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] return obs @property @@ -180,9 +184,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): "NICS": spaces.Dict( {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} ), - "num_file_creations": spaces.Discrete(4), - "num_file_deletions": spaces.Discrete(4), } + if self.include_num_access: + shape["num_file_creations"] = spaces.Discrete(4) + shape["num_file_deletions"] = spaces.Discrete(4) return spaces.Dict(shape) @classmethod diff --git a/src/primaite/game/agent/observations/link_observation.py b/src/primaite/game/agent/observations/link_observation.py index f810bb36..be08657d 100644 --- a/src/primaite/game/agent/observations/link_observation.py +++ b/src/primaite/game/agent/observations/link_observation.py @@ -132,7 +132,7 @@ class LinksObservation(AbstractObservation, identifier="LINKS"): :return: Gymnasium space representing the observation space for multiple links. 
:rtype: spaces.Space """ - return {i + 1: l.space for i, l in enumerate(self.links)} + return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)}) @classmethod def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation: diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 7d227bb7..dce33a04 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Dict, List, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import model_validator from primaite import getLogger from primaite.game.agent.observations.firewall_observation import FirewallObservation @@ -28,33 +29,63 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """List of configurations for router observations.""" firewalls: List[FirewallObservation.ConfigSchema] = [] """List of configurations for firewall observations.""" - num_services: int + num_services: Optional[int] = None """Number of services.""" - num_applications: int + num_applications: Optional[int] = None """Number of applications.""" - num_folders: int + num_folders: Optional[int] = None """Number of folders.""" - num_files: int + num_files: Optional[int] = None """Number of files.""" - num_nics: int + num_nics: Optional[int] = None """Number of network interface cards (NICs).""" - include_nmne: bool + include_nmne: Optional[bool] = None """Flag to include nmne.""" - include_num_access: bool + include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" - num_ports: int + num_ports: Optional[int] = None """Number of ports.""" - ip_list: List[str] + ip_list: Optional[List[str]] = None """List of IP addresses for 
encoding ACLs.""" - wildcard_list: List[str] + wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: List[int] + port_list: Optional[List[int]] = None """List of ports for encoding ACLs.""" - protocol_list: List[str] + protocol_list: Optional[List[str]] = None """List of protocols for encoding ACLs.""" - num_rules: int + num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + @model_validator(mode="after") + def force_optional_fields(self) -> NodesObservation.ConfigSchema: + """Check that options are specified only if they are needed for the nodes that are part of the config.""" + # check for hosts: + host_fields = ( + self.num_services, + self.num_applications, + self.num_folders, + self.num_files, + self.num_nics, + self.include_nmne, + self.include_num_access, + ) + router_fields = ( + self.num_ports, + self.ip_list, + self.wildcard_list, + self.port_list, + self.protocol_list, + self.num_rules, + ) + firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules) + if len(self.hosts) > 0 and any([x is None for x in host_fields]): + raise ValueError("Configuration error: Host observation options were not fully specified.") + if len(self.routers) > 0 and any([x is None for x in router_fields]): + raise ValueError("Configuration error: Router observation options were not fully specified.") + if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]): + raise ValueError("Configuration error: Firewall observation options were not fully specified.") + return self + def __init__( self, where: WhereType, diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 84311984..3703fa1c 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -1,6 +1,6 @@ from __future__ import annotations 
-from typing import Dict, List, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -120,6 +120,30 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): return cls(components=instances) +class NullObservation(AbstractObservation, identifier="NONE"): + """Empty observation that acts as a placeholder.""" + + def __init__(self) -> None: + """Initialise the empty observation.""" + self.default_observation = 0 + + def observe(self, state: Dict) -> Any: + """Simply return 0.""" + return 0 + + @property + def space(self) -> spaces.Space: + """Essentially empty space.""" + return spaces.Discrete(1) + + @classmethod + def from_config( + cls, config: NullObservation.ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] + ) -> NullObservation: + """Instantiate a NullObservation. Accepts parameters to comply with API.""" + return cls() + + class ObservationManager: """ Manage the observations of an Agent. @@ -156,7 +180,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": + def from_config(cls, config: Optional[Dict], game: "PrimaiteGame") -> "ObservationManager": """ Create observation space from a config. @@ -168,6 +192,9 @@ class ObservationManager: :param game: Reference to the PrimaiteGame object that spawned this observation. 
:type game: PrimaiteGame """ + if config is None: + return cls(NullObservation()) + print(config) obs_type = config["type"] obs_class = AbstractObservation._registry[obs_type] observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game) From 0e0df1012fb20e0fe1b9f1963de2bbb74c03b0d7 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 31 Mar 2024 23:39:24 +0100 Subject: [PATCH 12/16] #2417 update observations init to autoimport all obs types --- src/primaite/game/agent/observations/__init__.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py index b9d97ae6..15fdf7ed 100644 --- a/src/primaite/game/agent/observations/__init__.py +++ b/src/primaite/game/agent/observations/__init__.py @@ -1,4 +1,5 @@ # flake8: noqa +# Pre-import all the observations when we load up the observations module so that they can be resolved by the parser. from primaite.game.agent.observations.acl_observation import ACLObservation from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation from primaite.game.agent.observations.firewall_observation import FirewallObservation @@ -10,3 +11,10 @@ from primaite.game.agent.observations.observation_manager import NestedObservati from primaite.game.agent.observations.observations import AbstractObservation from primaite.game.agent.observations.router_observation import RouterObservation from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation + +# fmt: off +__all__ = [ + "ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation", + "LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation", + "ObservationManager", "ApplicationObservation", "ServiceObservation",] +# fmt: on From 0ba767d2a0988c404524a042d3d5a4396ac053d4 Mon Sep 17 
00:00:00 2001 From: Marek Wolan Date: Mon, 1 Apr 2024 00:54:55 +0100 Subject: [PATCH 13/16] #2417 update observation tests and make old tests pass --- .../_package_data/data_manipulation_marl.yaml | 251 +++++++++-------- .../agent/observations/acl_observation.py | 22 +- .../observations/file_system_observations.py | 16 +- .../agent/observations/host_observations.py | 52 ++-- .../agent/observations/nic_observations.py | 36 ++- .../agent/observations/router_observation.py | 13 +- .../assets/configs/bad_primaite_session.yaml | 130 ++++----- tests/assets/configs/basic_firewall.yaml | 3 +- .../configs/basic_switched_network.yaml | 3 +- tests/assets/configs/dmz_network.yaml | 3 +- .../configs/eval_only_primaite_session.yaml | 130 ++++----- tests/assets/configs/multi_agent_session.yaml | 252 ++++++++++-------- tests/assets/configs/shared_rewards.yaml | 131 ++++----- .../assets/configs/test_primaite_session.yaml | 132 ++++----- .../configs/train_only_primaite_session.yaml | 130 ++++----- .../test_primaite_session.py | 8 +- .../observations/test_acl_observations.py | 28 +- .../test_file_system_observations.py | 8 +- .../observations/test_nic_observations.py | 11 +- .../observations/test_node_observations.py | 27 +- .../game_layer/test_observations.py | 3 +- .../network/test_capture_nmne.py | 4 +- 22 files changed, 767 insertions(+), 626 deletions(-) diff --git a/src/primaite/config/_package_data/data_manipulation_marl.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml index b632f626..3e95a6ee 100644 --- a/src/primaite/config/_package_data/data_manipulation_marl.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -40,8 +40,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -90,8 +89,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: 
DONOTHING @@ -140,10 +138,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -179,61 +174,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: DNSServer - - node_hostname: web_server - services: - - service_name: WebServer - - node_hostname: database_server - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: 
client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: @@ -730,61 +737,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: DNSServer - - node_hostname: web_server - services: - - service_name: WebServer - - node_hostname: database_server - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: 
domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index ac599ea0..fc603a8a 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -59,10 +59,10 @@ class ACLObservation(AbstractObservation, identifier="ACL"): """ self.where = where self.num_rules: int = 
num_rules - self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)} - self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)} - self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)} + self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} + self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} + self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} + self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i + 1: { @@ -110,16 +110,16 @@ class ACLObservation(AbstractObservation, identifier="ACL"): } else: src_ip = rule_state["src_ip_address"] - src_node_id = self.ip_to_id.get(src_ip, 1) + src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip] dst_ip = rule_state["dst_ip_address"] - dst_node_ip = self.ip_to_id.get(dst_ip, 1) - src_wildcard = rule_state["source_wildcard_id"] + dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip] + src_wildcard = rule_state["src_wildcard_mask"] src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1) - dst_wildcard = rule_state["dest_wildcard_id"] + dst_wildcard = rule_state["dst_wildcard_mask"] dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1) - src_port = rule_state["source_port_id"] + src_port = rule_state["src_port"] src_port_id = self.port_to_id.get(src_port, 1) - dst_port = rule_state["dest_port_id"] + dst_port = rule_state["dst_port"] dst_port_id = self.port_to_id.get(dst_port, 1) protocol = rule_state["protocol"] protocol_id = self.protocol_to_id.get(protocol, 1) @@ -129,7 +129,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "source_ip_id": src_node_id, "source_wildcard_id": src_wildcard_id, "source_port_id": src_port_id, - "dest_ip_id": dst_node_ip, + "dest_ip_id": dst_node_id, 
"dest_wildcard_id": dst_wildcard_id, "dest_port_id": dst_port_id, "protocol_id": protocol_id, diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index a7c56a89..90bca35f 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -133,8 +133,9 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): self.default_observation = { "health_status": 0, - "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, } + if self.files: + self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)} def observe(self, state: Dict) -> ObsType: """ @@ -154,7 +155,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): obs = {} obs["health_status"] = health_status - obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + if self.files: + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} return obs @@ -166,12 +168,10 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): :return: Gymnasium space representing the observation space for folder status. 
:rtype: spaces.Space """ - return spaces.Dict( - { - "health_status": spaces.Discrete(6), - "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), - } - ) + shape = {"health_status": spaces.Discrete(6)} + if self.files: + shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}) + return spaces.Dict(shape) @classmethod def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation: diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 3ee5f2c7..8ea40be7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -123,21 +123,27 @@ class HostObservation(AbstractObservation, identifier="HOST"): msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}" _LOGGER.warning(msg) - self.network_interfaces: List[NICObservation] = network_interfaces - while len(self.network_interfaces) < num_nics: - self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne)) - while len(self.network_interfaces) > num_nics: - truncated_nic = self.network_interfaces.pop() + self.nics: List[NICObservation] = network_interfaces + while len(self.nics) < num_nics: + self.nics.append(NICObservation(where=None, include_nmne=include_nmne)) + while len(self.nics) > num_nics: + truncated_nic = self.nics.pop() msg = f"Too many network_interfaces in Node observation space for node. 
Truncating {truncated_nic.where}" _LOGGER.warning(msg) self.default_observation: ObsType = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, "operating_status": 0, } + if self.services: + self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)} + if self.applications: + self.default_observation["APPLICATIONS"] = { + i + 1: a.default_observation for i, a in enumerate(self.applications) + } + if self.folders: + self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)} + if self.nics: + self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)} if self.include_num_access: self.default_observation["num_file_creations"] = 0 self.default_observation["num_file_deletions"] = 0 @@ -156,13 +162,15 @@ class HostObservation(AbstractObservation, identifier="HOST"): return self.default_observation obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} obs["operating_status"] = node_state["operating_state"] - obs["NICS"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } + if self.services: + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + if self.applications: + obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} + if self.folders: + obs["FOLDERS"] = {i + 1: 
folder.observe(state) for i, folder in enumerate(self.folders)} + if self.nics: + obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)} if self.include_num_access: obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] @@ -177,14 +185,16 @@ class HostObservation(AbstractObservation, identifier="HOST"): :rtype: spaces.Space """ shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), "operating_status": spaces.Discrete(5), - "NICS": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), } + if self.services: + shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}) + if self.applications: + shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}) + if self.folders: + shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}) + if self.nics: + shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}) if self.include_num_access: shape["num_file_creations"] = spaces.Discrete(4) shape["num_file_deletions"] = spaces.Discrete(4) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 19826f84..44cc7f8f 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -23,7 +23,11 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): include_nmne: Optional[bool] = None """Whether to include number of malicious network events (NMNE) in the observation.""" - 
def __init__(self, where: WhereType, include_nmne: bool) -> None: + def __init__( + self, + where: WhereType, + include_nmne: bool, + ) -> None: """ Initialise a network interface observation instance. @@ -40,6 +44,36 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): self.default_observation: ObsType = {"nic_status": 0} if self.include_nmne: self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}}) + self.nmne_inbound_last_step: int = 0 + self.nmne_outbound_last_step: int = 0 + + # TODO: allow these to be configured in yaml + self.high_nmne_threshold = 10 + self.med_nmne_threshold = 5 + self.low_nmne_threshold = 0 + + def _categorise_mne_count(self, nmne_count: int) -> int: + """ + Categorise the number of Malicious Network Events (NMNEs) into discrete bins. + + This helps in classifying the severity or volume of MNEs into manageable levels for the agent. + + Bins are defined as follows: + - 0: No MNEs detected (0 events). + - 1: Low number of MNEs (default 1-5 events). + - 2: Moderate number of MNEs (default 6-10 events). + - 3: High number of MNEs (default more than 10 events). + + :param nmne_count: Number of MNEs detected. + :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. 
+ """ + if nmne_count > self.high_nmne_threshold: + return 3 + elif nmne_count > self.med_nmne_threshold: + return 2 + elif nmne_count > self.low_nmne_threshold: + return 1 + return 0 def observe(self, state: Dict) -> ObsType: """ diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index c2919770..a7879f09 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -74,9 +74,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): _LOGGER.warning(msg) self.default_observation = { - "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, "ACL": self.acl.default_observation, } + if self.ports: + self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)} def observe(self, state: Dict) -> ObsType: """ @@ -92,8 +93,9 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): return self.default_observation obs = {} - obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} obs["ACL"] = self.acl.observe(state) + if self.ports: + obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} return obs @property @@ -104,9 +106,10 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): :return: Gymnasium space representing the observation space for router status. 
:rtype: spaces.Space """ - return spaces.Dict( - {"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space} - ) + shape = {"ACL": self.acl.space} + if self.ports: + shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}) + return spaces.Dict(shape) @classmethod def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation: diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index e599ee7e..c613008e 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -22,8 +22,7 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -50,10 +49,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -86,63 +82,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: domain_controller_dns_server - - node_hostname: web_server - services: - - service_name: web_server_database_client - - node_hostname: database_server - services: - - service_name: database_service - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: 
switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml index 9d7b34cb..5de704dc 100644 --- 
a/tests/assets/configs/basic_firewall.yaml +++ b/tests/assets/configs/basic_firewall.yaml @@ -41,8 +41,7 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index 9a0d5313..aab6b780 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -41,8 +41,7 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml index 95e09e16..076c174a 100644 --- a/tests/assets/configs/dmz_network.yaml +++ b/tests/assets/configs/dmz_network.yaml @@ -66,8 +66,7 @@ agents: - ref: client_1_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 9d1404d8..a4450264 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -26,8 +26,7 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -55,10 +54,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -90,63 +86,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 
1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: domain_controller_dns_server - - node_hostname: web_server - services: - - service_name: web_server_database_client - - node_hostname: database_server - services: - - service_name: database_service - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 
192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index acb62c96..8723e71a 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -32,8 +32,7 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -61,10 +60,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -97,63 +93,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: domain_controller_dns_server - - node_hostname: web_server - services: - - service_name: web_server_database_client - - node_hostname: database_server - services: - - service_name: database_service - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: 
router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + 
label: ICS + options: {} action_space: action_list: @@ -541,63 +547,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: domain_controller_dns_server - - node_hostname: web_server - services: - - service_name: web_server_database_client - - node_hostname: database_server - services: - - service_name: database_service - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + 
num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/assets/configs/shared_rewards.yaml b/tests/assets/configs/shared_rewards.yaml index 10feba9d..9acf3ad5 100644 --- a/tests/assets/configs/shared_rewards.yaml +++ b/tests/assets/configs/shared_rewards.yaml @@ -41,8 +41,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -91,8 +90,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -141,10 +139,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -177,61 +172,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: DNSServer - - node_hostname: web_server - services: - - service_name: WebServer - - node_hostname: database_server - folders: - - 
folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - 
switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index a8b33032..9391084a 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -33,8 +33,7 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -62,10 +61,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -98,65 +94,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: domain_controller_dns_server - - node_hostname: web_server - services: - - service_name: web_server_database_client - - node_hostname: database_server - services: - - service_name: database_service - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - # services: - # - service_name: backup_service - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - 
link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index d0cbaab3..5e00928b 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ 
b/tests/assets/configs/train_only_primaite_session.yaml @@ -26,8 +26,7 @@ agents: - ref: client_2_green_user team: GREEN type: ProbabilisticAgent - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -62,10 +61,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -98,63 +94,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: domain_controller_dns_server - - node_hostname: web_server - services: - - service_name: web_server_database_client - - node_hostname: database_server - services: - - service_name: database_service - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + 
components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index c45a4690..4e9ba723 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -11,8 +11,9 @@ MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml" MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml" -# @pytest.mark.skip(reason="no way of currently testing this") +@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.") class TestPrimaiteSession: + @pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.") 
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_creating_session(self, temp_primaite_session): """Check that creating a session from config works.""" @@ -51,6 +52,7 @@ class TestPrimaiteSession: assert checkpoint_2.exists() assert not checkpoint_3.exists() + @pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.") @pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True) def test_training_only_session(self, temp_primaite_session): """Check that you can run a training-only session.""" @@ -59,6 +61,7 @@ class TestPrimaiteSession: session.start_session() # TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved? + @pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.") @pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True) def test_eval_only_session(self, temp_primaite_session): """Check that you can load a model and run an eval-only session.""" @@ -67,6 +70,7 @@ class TestPrimaiteSession: session.start_session() # TODO: include checks that the model was loaded and that the eval-only session ran + @pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.") @pytest.mark.skip(reason="Slow, reenable later") @pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True) def test_multi_agent_session(self, temp_primaite_session): @@ -74,10 +78,12 @@ class TestPrimaiteSession: with temp_primaite_session as session: session.start_session() + @pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.") def test_error_thrown_on_bad_configuration(self): with pytest.raises(pydantic.ValidationError): session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH) + @pytest.mark.skip(reason="Session is 
not being maintained and will be removed in the subsequent beta release.") @pytest.mark.skip( reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, " "reset doesn't implement software restore." diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py index d0710f5f..5aa2ec2a 100644 --- a/tests/integration_tests/game_layer/observations/test_acl_observations.py +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -36,9 +36,11 @@ def test_acl_observations(simulation): acl_obs = ACLObservation( where=["network", "nodes", router.hostname, "acl", "acl"], - node_ip_to_id={}, - ports=["NTP", "HTTP", "POSTGRES_SERVER"], - protocols=["TCP", "UDP", "ICMP"], + ip_list=[], + port_list=["NTP", "HTTP", "POSTGRES_SERVER"], + protocol_list=["TCP", "UDP", "ICMP"], + num_rules=10, + wildcard_list=[], ) observation_space = acl_obs.observe(simulation.describe_state()) @@ -46,11 +48,11 @@ def test_acl_observations(simulation): rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1) assert rule_obs.get("permission") == 1 # permit = 1 deny = 2 - assert rule_obs.get("source_node_id") == 1 # applies to all source nodes - assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes - assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs) - assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2 - assert rule_obs.get("protocol") == 1 # 1 = No Protocol + assert rule_obs.get("source_ip_id") == 1 # applies to all source nodes + assert rule_obs.get("dest_ip_id") == 1 # applies to all destination nodes + assert rule_obs.get("source_port_id") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs) + assert 
rule_obs.get("dest_port_id") == 2 # NTP port is mapped to value 2 + assert rule_obs.get("protocol_id") == 1 # 1 = No Protocol router.acl.remove_rule(1) @@ -59,8 +61,8 @@ def test_acl_observations(simulation): rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP assert rule_obs.get("position") == 0 assert rule_obs.get("permission") == 0 - assert rule_obs.get("source_node_id") == 0 - assert rule_obs.get("dest_node_id") == 0 - assert rule_obs.get("source_port") == 0 - assert rule_obs.get("dest_port") == 0 - assert rule_obs.get("protocol") == 0 + assert rule_obs.get("source_ip_id") == 0 + assert rule_obs.get("dest_ip_id") == 0 + assert rule_obs.get("source_port_id") == 0 + assert rule_obs.get("dest_port_id") == 0 + assert rule_obs.get("protocol_id") == 0 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index 35bb95fd..af5e9650 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -23,7 +23,8 @@ def test_file_observation(simulation): file = pc.file_system.create_file(file_name="dog.png") dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, ) assert dog_file_obs.space["health_status"] == spaces.Discrete(6) @@ -49,7 +50,10 @@ def test_folder_observation(simulation): file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") root_folder_obs = FolderObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"] + where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"], + include_num_access=False, + num_files=1, + 
files=[], ) assert root_folder_obs.space["health_status"] == spaces.Discrete(6) diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py index bc4261ce..66b7df55 100644 --- a/tests/integration_tests/game_layer/observations/test_nic_observations.py +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -40,7 +40,7 @@ def test_nic(simulation): nic: NIC = pc.network_interface[1] - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) assert nic_obs.space["nic_status"] == spaces.Discrete(3) assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) @@ -61,17 +61,22 @@ def test_nic_categories(simulation): """Test the NIC observation nmne count categories.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True) assert nic_obs.high_nmne_threshold == 10 # default assert nic_obs.med_nmne_threshold == 5 # default assert nic_obs.low_nmne_threshold == 0 # default + +@pytest.mark.skip(reason="Feature not implemented yet") +def test_config_nic_categories(simulation): + pc: Computer = simulation.network.get_node_by_hostname("client_1") nic_obs = NICObservation( where=["network", "nodes", pc.hostname, "NICs", 1], low_nmne_threshold=3, med_nmne_threshold=6, high_nmne_threshold=9, + include_nmne=True, ) assert nic_obs.high_nmne_threshold == 9 @@ -85,6 +90,7 @@ def test_nic_categories(simulation): low_nmne_threshold=9, med_nmne_threshold=6, high_nmne_threshold=9, + include_nmne=True, ) with pytest.raises(Exception): @@ -94,4 +100,5 @@ def test_nic_categories(simulation): low_nmne_threshold=3, med_nmne_threshold=9, high_nmne_threshold=9, + 
include_nmne=True, ) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index 2926ffa6..458cf0ab 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -19,15 +19,28 @@ def simulation(example_network) -> Simulation: return sim -def test_node_observation(simulation): - """Test a Node observation.""" +def test_host_observation(simulation): + """Test a Host observation.""" pc: Computer = simulation.network.get_node_by_hostname("client_1") - node_obs = HostObservation(where=["network", "nodes", pc.hostname]) + host_obs = HostObservation( + where=["network", "nodes", pc.hostname], + num_applications=0, + num_files=1, + num_folders=1, + num_nics=2, + num_services=1, + include_num_access=False, + include_nmne=False, + services=[], + applications=[], + folders=[], + network_interfaces=[], + ) - assert node_obs.space["operating_status"] == spaces.Discrete(5) + assert host_obs.space["operating_status"] == spaces.Discrete(5) - observation_state = node_obs.observe(simulation.describe_state()) + observation_state = host_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 1 # computer is on assert observation_state.get("SERVICES") is not None @@ -36,11 +49,11 @@ def test_node_observation(simulation): # turn off computer pc.power_off() - observation_state = node_obs.observe(simulation.describe_state()) + observation_state = host_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 4 # shutting down for i in range(pc.shut_down_duration + 1): pc.apply_timestep(i) - observation_state = node_obs.observe(simulation.describe_state()) + observation_state = host_obs.observe(simulation.describe_state()) assert observation_state.get("operating_status") == 2 diff --git 
a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index f52b52f7..0a34ab67 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -14,7 +14,8 @@ def test_file_observation(): state = sim.describe_state() dog_file_obs = FileObservation( - where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"], + include_num_access=False, ) assert dog_file_obs.observe(state) == {"health_status": 1} assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py index 1578305b..6601831f 100644 --- a/tests/integration_tests/network/test_capture_nmne.py +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -168,8 +168,8 @@ def test_capture_nmne_observations(uc2_network): set_nmne_config(nmne_config) # Define observations for the NICs of the database and web servers - db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1]) - web_server_nic_obs = NICObservation(where=["network", "nodes", "web_server", "NICs", 1]) + db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True) + web_server_nic_obs = NICObservation(where=["network", "nodes", "web_server", "NICs", 1], include_nmne=True) # Iterate through a set of test cases to simulate multiple DELETE queries for i in range(0, 20): From 709486d739b31ea474432bdcc9e5dc8f4b4d0bb6 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Apr 2024 16:06:12 +0100 Subject: [PATCH 14/16] #2417 test firewall and router obs --- .../agent/observations/acl_observation.py | 9 +- .../observations/firewall_observation.py | 96 +++++++------ 
.../observations/test_firewall_observation.py | 128 ++++++++++++++++++ .../observations/test_router_observation.py | 108 +++++++++++++++ 4 files changed, 292 insertions(+), 49 deletions(-) create mode 100644 tests/integration_tests/game_layer/observations/test_firewall_observation.py create mode 100644 tests/integration_tests/game_layer/observations/test_router_observation.py diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index fc603a8a..8b3d8ab5 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -64,8 +64,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { - i - + 1: { + i: { "position": i, "permission": 0, "source_ip_id": 0, @@ -76,7 +75,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_port_id": 0, "protocol_id": 0, } - for i in range(self.num_rules) + for i in range(1, self.num_rules + 1) } def observe(self, state: Dict) -> ObsType: @@ -98,7 +97,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): rule_state = acl_items[i] if rule_state is None: obs[i] = { - "position": i - 1, + "position": i, "permission": 0, "source_ip_id": 0, "source_wildcard_id": 0, @@ -124,7 +123,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): protocol = rule_state["protocol"] protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { - "position": i - 1, + "position": i, "permission": rule_state["action"], "source_ip_id": src_node_id, "source_wildcard_id": src_wildcard_id, diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 69398d96..ab48e606 100644 --- 
a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -63,12 +63,12 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): self.where: WhereType = where self.ports: List[PortObservation] = [ - PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3) + PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3) ] # TODO: check what the port nums are for firewall. self.internal_inbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "inbound"], + where=self.where + ["internal_inbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -76,7 +76,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.internal_outbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "outbound"], + where=self.where + ["internal_outbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -84,7 +84,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.dmz_inbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "inbound"], + where=self.where + ["dmz_inbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -92,7 +92,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.dmz_outbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "outbound"], + where=self.where + ["dmz_outbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -100,7 +100,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.external_inbound_acl = ACLObservation( - where=self.where + ["acl", "external", "inbound"], + where=self.where + ["external_inbound_acl", "acl"], 
num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -108,7 +108,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): protocol_list=protocol_list, ) self.external_outbound_acl = ACLObservation( - where=self.where + ["acl", "external", "outbound"], + where=self.where + ["external_outbound_acl", "acl"], num_rules=num_rules, ip_list=ip_list, wildcard_list=wildcard_list, @@ -118,17 +118,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): self.default_observation = { "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.default_observation, - "OUTBOUND": self.internal_outbound_acl.default_observation, - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.default_observation, - "OUTBOUND": self.dmz_outbound_acl.default_observation, - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.default_observation, - "OUTBOUND": self.external_outbound_acl.default_observation, + "ACL": { + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.default_observation, + "OUTBOUND": self.internal_outbound_acl.default_observation, + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.default_observation, + "OUTBOUND": self.dmz_outbound_acl.default_observation, + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.default_observation, + "OUTBOUND": self.external_outbound_acl.default_observation, + }, }, } @@ -143,17 +145,19 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): """ obs = { "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.observe(state), - "OUTBOUND": self.internal_outbound_acl.observe(state), - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.observe(state), - "OUTBOUND": self.dmz_outbound_acl.observe(state), - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.observe(state), - "OUTBOUND": self.external_outbound_acl.observe(state), 
+ "ACL": { + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, }, } return obs @@ -169,22 +173,26 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): space = spaces.Dict( { "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), - "INTERNAL": spaces.Dict( + "ACL": spaces.Dict( { - "INBOUND": self.internal_inbound_acl.space, - "OUTBOUND": self.internal_outbound_acl.space, - } - ), - "DMZ": spaces.Dict( - { - "INBOUND": self.dmz_inbound_acl.space, - "OUTBOUND": self.dmz_outbound_acl.space, - } - ), - "EXTERNAL": spaces.Dict( - { - "INBOUND": self.external_inbound_acl.space, - "OUTBOUND": self.external_outbound_acl.space, + "INTERNAL": spaces.Dict( + { + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + } + ), + "DMZ": spaces.Dict( + { + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + } + ), + "EXTERNAL": spaces.Dict( + { + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + } + ), } ), } diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py new file mode 100644 index 00000000..12a84e9a --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -0,0 +1,128 @@ +from primaite.game.agent.observations.firewall_observation import FirewallObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from 
primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port + + +def check_default_rules(acl_obs): + assert len(acl_obs) == 7 + assert all(acl_obs[i]["position"] == i for i in range(1, 8)) + assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["source_port_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(1, 8)) + assert all(acl_obs[i]["protocol_id"] == 0 for i in range(1, 8)) + + +def test_firewall_observation(): + """Test adding/removing acl rules and enabling/disabling ports.""" + net = Network() + firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON) + firewall_observation = FirewallObservation( + where=[], + num_rules=7, + ip_list=["10.0.0.1", "10.0.0.2"], + wildcard_list=["0.0.0.255", "0.0.0.1"], + port_list=["HTTP", "DNS"], + protocol_list=["TCP"], + ) + + observation = firewall_observation.observe(firewall.describe_state()) + assert "ACL" in observation + assert "PORTS" in observation + assert "INTERNAL" in observation["ACL"] + assert "EXTERNAL" in observation["ACL"] + assert "DMZ" in observation["ACL"] + assert "INBOUND" in observation["ACL"]["INTERNAL"] + assert "OUTBOUND" in observation["ACL"]["INTERNAL"] + assert "INBOUND" in observation["ACL"]["EXTERNAL"] + assert "OUTBOUND" in observation["ACL"]["EXTERNAL"] + assert "INBOUND" in observation["ACL"]["DMZ"] + assert 
"OUTBOUND" in observation["ACL"]["DMZ"] + all_acls = ( + observation["ACL"]["INTERNAL"]["INBOUND"], + observation["ACL"]["INTERNAL"]["OUTBOUND"], + observation["ACL"]["EXTERNAL"]["INBOUND"], + observation["ACL"]["EXTERNAL"]["OUTBOUND"], + observation["ACL"]["DMZ"]["INBOUND"], + observation["ACL"]["DMZ"]["OUTBOUND"], + ) + for acl_obs in all_acls: + check_default_rules(acl_obs) + + # add a rule to the internal inbound and check that the observation is correct + firewall.internal_inbound_acl.add_rule( + action=ACLAction.DENY, + protocol=IPProtocol.TCP, + src_ip_address="10.0.0.1", + src_wildcard_mask="0.0.0.1", + dst_ip_address="10.0.0.2", + dst_wildcard_mask="0.0.0.1", + src_port=Port.HTTP, + dst_port=Port.HTTP, + position=5, + ) + + observation = firewall_observation.observe(firewall.describe_state()) + observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5] + assert observed_rule["position"] == 5 + assert observed_rule["permission"] == 2 + assert observed_rule["source_ip_id"] == 2 + assert observed_rule["source_wildcard_id"] == 3 + assert observed_rule["source_port_id"] == 2 + assert observed_rule["dest_ip_id"] == 3 + assert observed_rule["dest_wildcard_id"] == 3 + assert observed_rule["dest_port_id"] == 2 + assert observed_rule["protocol_id"] == 2 + + # check that none of the other acls have changed + all_acls = ( + observation["ACL"]["INTERNAL"]["OUTBOUND"], + observation["ACL"]["EXTERNAL"]["INBOUND"], + observation["ACL"]["EXTERNAL"]["OUTBOUND"], + observation["ACL"]["DMZ"]["INBOUND"], + observation["ACL"]["DMZ"]["OUTBOUND"], + ) + for acl_obs in all_acls: + check_default_rules(acl_obs) + + # remove the rule and check that the observation is correct + firewall.internal_inbound_acl.remove_rule(5) + observation = firewall_observation.observe(firewall.describe_state()) + all_acls = ( + observation["ACL"]["INTERNAL"]["INBOUND"], + observation["ACL"]["INTERNAL"]["OUTBOUND"], + observation["ACL"]["EXTERNAL"]["INBOUND"], + 
observation["ACL"]["EXTERNAL"]["OUTBOUND"], + observation["ACL"]["DMZ"]["INBOUND"], + observation["ACL"]["DMZ"]["OUTBOUND"], + ) + for acl_obs in all_acls: + check_default_rules(acl_obs) + + # check that there are three ports in the observation + assert len(observation["PORTS"]) == 3 + + # check that the ports are all disabled + assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) + + # connect a switch to the firewall and check that only the correct port is updated + switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + link = net.connect(firewall.network_interface[1], switch.network_interface[1]) + assert firewall.network_interface[1].enabled + observation = firewall_observation.observe(firewall.describe_state()) + assert observation["PORTS"][1]["operating_status"] == 1 + assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(2, 4)) + + # disable the port and check that the operating status is updated + firewall.network_interface[1].disable() + assert not firewall.network_interface[1].enabled + observation = firewall_observation.observe(firewall.describe_state()) + assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4)) diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py new file mode 100644 index 00000000..7db6a2c2 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -0,0 +1,108 @@ +from pprint import pprint + +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.nic_observations import PortObservation +from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from 
primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation + + +def test_router_observation(): + """Test adding/removing acl rules and enabling/disabling ports.""" + net = Network() + router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON) + + ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)] + acl = ACLObservation( + where=["acl", "acl"], + num_rules=7, + ip_list=["10.0.0.1", "10.0.0.2"], + wildcard_list=["0.0.0.255", "0.0.0.1"], + port_list=["HTTP", "DNS"], + protocol_list=["TCP"], + ) + router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl) + + # Observe the state using the RouterObservation instance + observed_output = router_observation.observe(router.describe_state()) + + # Check that the right number of ports and acls are in the router observation + assert len(observed_output["PORTS"]) == 8 + assert len(observed_output["ACL"]) == 7 + + # Add an ACL rule to the router + router.acl.add_rule( + action=ACLAction.DENY, + protocol=IPProtocol.TCP, + src_ip_address="10.0.0.1", + src_wildcard_mask="0.0.0.1", + dst_ip_address="10.0.0.2", + dst_wildcard_mask="0.0.0.1", + src_port=Port.HTTP, + dst_port=Port.HTTP, + position=5, + ) + # Observe the state using the RouterObservation instance + observed_output = router_observation.observe(router.describe_state()) + observed_rule = observed_output["ACL"][5] + assert observed_rule["position"] == 5 + assert observed_rule["permission"] == 2 + assert observed_rule["source_ip_id"] == 2 + assert observed_rule["source_wildcard_id"] == 3 + assert observed_rule["source_port_id"] == 2 + assert observed_rule["dest_ip_id"] == 3 + assert 
observed_rule["dest_wildcard_id"] == 3 + assert observed_rule["dest_port_id"] == 2 + assert observed_rule["protocol_id"] == 2 + + # Add an ACL rule with ALL/NONE values and check that the observation is correct + router.acl.add_rule( + action=ACLAction.PERMIT, + protocol=None, + src_ip_address=None, + src_wildcard_mask=None, + dst_ip_address=None, + dst_wildcard_mask=None, + src_port=None, + dst_port=None, + position=2, + ) + observed_output = router_observation.observe(router.describe_state()) + observed_rule = observed_output["ACL"][2] + assert observed_rule["position"] == 2 + assert observed_rule["permission"] == 1 + assert observed_rule["source_ip_id"] == 1 + assert observed_rule["source_wildcard_id"] == 1 + assert observed_rule["source_port_id"] == 1 + assert observed_rule["dest_ip_id"] == 1 + assert observed_rule["dest_wildcard_id"] == 1 + assert observed_rule["dest_port_id"] == 1 + assert observed_rule["protocol_id"] == 1 + + # Check that the router ports are all disabled + assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) + + # connect a switch to the router and check that only the correct port is updated + switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON) + link = net.connect(router.network_interface[1], switch.network_interface[1]) + assert router.network_interface[1].enabled + observed_output = router_observation.observe(router.describe_state()) + assert observed_output["PORTS"][1]["operating_status"] == 1 + assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(2, 6)) + + # disable the port and check that the operating status is updated + router.network_interface[1].disable() + assert not router.network_interface[1].enabled + observed_output = router_observation.observe(router.describe_state()) + assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6)) + + # Check that ports that are out of range are shown as unused + 
observed_output = router_observation.observe(router.describe_state()) + assert observed_output["PORTS"][6]["operating_status"] == 0 + assert observed_output["PORTS"][7]["operating_status"] == 0 + assert observed_output["PORTS"][8]["operating_status"] == 0 From 2513646205bbe608697d238c2d88380802d2bb0b Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Apr 2024 16:50:59 +0100 Subject: [PATCH 15/16] #2417 fix last observation tests --- .../agent/observations/acl_observation.py | 9 ++-- .../observations/test_firewall_observation.py | 4 +- .../observations/test_link_observations.py | 42 +++++++++++++++++++ .../observations/test_router_observation.py | 4 +- 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 8b3d8ab5..fc603a8a 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -64,7 +64,8 @@ class ACLObservation(AbstractObservation, identifier="ACL"): self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { - i: { + i + + 1: { "position": i, "permission": 0, "source_ip_id": 0, @@ -75,7 +76,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "dest_port_id": 0, "protocol_id": 0, } - for i in range(1, self.num_rules + 1) + for i in range(self.num_rules) } def observe(self, state: Dict) -> ObsType: @@ -97,7 +98,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): rule_state = acl_items[i] if rule_state is None: obs[i] = { - "position": i, + "position": i - 1, "permission": 0, "source_ip_id": 0, "source_wildcard_id": 0, @@ -123,7 +124,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): protocol = rule_state["protocol"] protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { - 
"position": i, + "position": i - 1, "permission": rule_state["action"], "source_ip_id": src_node_id, "source_wildcard_id": src_wildcard_id, diff --git a/tests/integration_tests/game_layer/observations/test_firewall_observation.py b/tests/integration_tests/game_layer/observations/test_firewall_observation.py index 12a84e9a..959e30f6 100644 --- a/tests/integration_tests/game_layer/observations/test_firewall_observation.py +++ b/tests/integration_tests/game_layer/observations/test_firewall_observation.py @@ -10,7 +10,7 @@ from primaite.simulator.network.transmission.transport_layer import Port def check_default_rules(acl_obs): assert len(acl_obs) == 7 - assert all(acl_obs[i]["position"] == i for i in range(1, 8)) + assert all(acl_obs[i]["position"] == i - 1 for i in range(1, 8)) assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8)) assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8)) assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8)) @@ -72,7 +72,7 @@ def test_firewall_observation(): observation = firewall_observation.observe(firewall.describe_state()) observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5] - assert observed_rule["position"] == 5 + assert observed_rule["position"] == 4 assert observed_rule["permission"] == 2 assert observed_rule["source_ip_id"] == 2 assert observed_rule["source_wildcard_id"] == 3 diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index b13314f1..1a41cad4 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -4,8 +4,10 @@ from gymnasium import spaces from primaite.game.agent.observations.link_observation import LinkObservation from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.base import Link, Node +from 
primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.sim_container import Simulation @@ -71,3 +73,43 @@ def test_link_observation(simulation): observation_state = link_obs.observe(simulation.describe_state()) assert observation_state["PROTOCOLS"]["ALL"] == 1 + + +def test_link_observation_again(): + net = Network() + sim = Simulation(network=net) + switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) + computer_1 = Computer( + hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0 + ) + computer_2 = Computer( + hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0 + ) + computer_1.power_on() + computer_2.power_on() + link_1 = net.connect(switch.network_interface[1], computer_1.network_interface[1]) + link_2 = net.connect(switch.network_interface[2], computer_2.network_interface[1]) + assert link_1 is not None + assert link_2 is not None + + link_1_observation = LinkObservation(where=["network", "links", link_1.uuid]) + link_2_observation = LinkObservation(where=["network", "links", link_2.uuid]) + + state = sim.describe_state() + link_1_obs = link_1_observation.observe(state) + link_2_obs = link_2_observation.observe(state) + assert "PROTOCOLS" in link_1_obs + assert "PROTOCOLS" in link_2_obs + assert "ALL" in link_1_obs["PROTOCOLS"] + assert "ALL" in link_2_obs["PROTOCOLS"] + assert link_1_obs["PROTOCOLS"]["ALL"] == 0 + assert link_2_obs["PROTOCOLS"]["ALL"] == 0 + + # Test that the link observation is updated when a packet is sent + computer_1.ping("10.0.0.2") + computer_2.ping("10.0.0.1") + state = sim.describe_state() + link_1_obs = link_1_observation.observe(state) + 
link_2_obs = link_2_observation.observe(state) + assert link_1_obs["PROTOCOLS"]["ALL"] > 0 + assert link_2_obs["PROTOCOLS"]["ALL"] > 0 diff --git a/tests/integration_tests/game_layer/observations/test_router_observation.py b/tests/integration_tests/game_layer/observations/test_router_observation.py index 7db6a2c2..55471676 100644 --- a/tests/integration_tests/game_layer/observations/test_router_observation.py +++ b/tests/integration_tests/game_layer/observations/test_router_observation.py @@ -50,7 +50,7 @@ def test_router_observation(): # Observe the state using the RouterObservation instance observed_output = router_observation.observe(router.describe_state()) observed_rule = observed_output["ACL"][5] - assert observed_rule["position"] == 5 + assert observed_rule["position"] == 4 assert observed_rule["permission"] == 2 assert observed_rule["source_ip_id"] == 2 assert observed_rule["source_wildcard_id"] == 3 @@ -74,7 +74,7 @@ def test_router_observation(): ) observed_output = router_observation.observe(router.describe_state()) observed_rule = observed_output["ACL"][2] - assert observed_rule["position"] == 2 + assert observed_rule["position"] == 1 assert observed_rule["permission"] == 1 assert observed_rule["source_ip_id"] == 1 assert observed_rule["source_wildcard_id"] == 1 From d2c7ae481c975fddac7af3c2ae900abc624ac2a4 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Mon, 1 Apr 2024 22:03:28 +0100 Subject: [PATCH 16/16] #2417 Add categorisation and updated new configs from merge --- .../observations/file_system_observations.py | 23 ++- .../observations/firewall_observation.py | 3 +- .../agent/observations/observation_manager.py | 8 +- .../observations/software_observation.py | 22 ++- .../configs/firewall_actions_network.yaml | 76 +++++++--- .../configs/test_application_install.yaml | 131 +++++++++--------- .../test_file_system_observations.py | 3 + .../observations/test_link_observations.py | 29 +--- .../game_layer/test_observations.py | 5 + 9 files changed, 188 
insertions(+), 112 deletions(-) diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 90bca35f..9b9434af 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -43,6 +43,26 @@ class FileObservation(AbstractObservation, identifier="FILE"): if self.include_num_access: self.default_observation["num_access"] = 0 + # TODO: allow these to be configured in yaml + self.high_threshold = 10 + self.med_threshold = 5 + self.low_threshold = 0 + + def _categorise_num_access(self, num_access: int) -> int: + """ + Represent number of file accesses as a categorical variable. + + :param num_access: Number of file accesses. + :return: Bin number corresponding to the number of accesses. + """ + if num_access > self.high_threshold: + return 3 + elif num_access > self.med_threshold: + return 2 + elif num_access > self.low_threshold: + return 1 + return 0 + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. 
@@ -57,8 +77,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): return self.default_observation obs = {"health_status": file_state["visible_status"]} if self.include_num_access: - obs["num_access"] = file_state["num_access"] - # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") + obs["num_access"] = self._categorise_num_access(file_state["num_access"]) return obs @property diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index ab48e606..0c10a8d2 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -214,9 +214,8 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): :return: Constructed firewall observation instance. :rtype: FirewallObservation """ - where = parent_where + ["nodes", config.hostname] return cls( - where=where, + where=parent_where + ["nodes", config.hostname], ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 3703fa1c..1d428fa8 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -185,9 +185,11 @@ class ObservationManager: Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. - It should contain the key 'type' which selects which observation class to use (from a choice of: - UC2BlueObservation, UC2RedObservation, UC2GreenObservation) - The other key is 'options' which are passed to the constructor of the selected observation class. + If None, a blank observation space is created. + Otherwise, this must be a Dict with a type field and options field. 
+ type: string that corresponds to one of the observation identifiers that are provided when subclassing + AbstractObservation + options: this must adhere to the chosen observation type's ConfigSchema nested class. :type config: Dict :param game: Reference to the PrimaiteGame object that spawned this observation. :type game: PrimaiteGame diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 40788760..2c4806d9 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -98,6 +98,26 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): self.where = where self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} + # TODO: allow these to be configured in yaml + self.high_threshold = 10 + self.med_threshold = 5 + self.low_threshold = 0 + + def _categorise_num_executions(self, num_executions: int) -> int: + """ + Represent number of application executions as a categorical variable. + + :param num_executions: Number of application executions. + :return: Bin number corresponding to the number of executions. + """ + if num_executions > self.high_threshold: + return 3 + elif num_executions > self.med_threshold: + return 2 + elif num_executions > self.low_threshold: + return 1 + return 0 + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. 
@@ -113,7 +133,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): return { "operating_status": application_state["operating_state"], "health_status": application_state["health_state_visible"], - "num_executions": application_state["num_executions"], + "num_executions": self._categorise_num_executions(application_state["num_executions"]), } @property diff --git a/tests/assets/configs/firewall_actions_network.yaml b/tests/assets/configs/firewall_actions_network.yaml index b7848c53..203ea3ea 100644 --- a/tests/assets/configs/firewall_actions_network.yaml +++ b/tests/assets/configs/firewall_actions_network.yaml @@ -64,25 +64,67 @@ agents: - ref: defender team: BLUE type: ProxyAgent + observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: client_1 - links: - - link_ref: client_1___switch_1 - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: client_1 - nic_num: 1 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: client_1 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.0.10 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - client_1___switch_1 + - type: "NONE" + label: ICS + options: {} + + # observation_space: + # type: UC2BlueObservation + # options: + # num_services_per_node: 1 + # num_folders_per_node: 1 + # num_files_per_folder: 1 + # num_nics_per_node: 2 + # nodes: + # - node_hostname: client_1 + # links: + # - link_ref: client_1___switch_1 + # acl: + # options: + # max_acl_rules: 10 + # router_hostname: router_1 + 
# ip_address_order: + # - node_hostname: client_1 + # nic_num: 1 + # ics: null action_space: action_list: - type: DONOTHING diff --git a/tests/assets/configs/test_application_install.yaml b/tests/assets/configs/test_application_install.yaml index b3fca4bc..ccd2228c 100644 --- a/tests/assets/configs/test_application_install.yaml +++ b/tests/assets/configs/test_application_install.yaml @@ -41,8 +41,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -91,8 +90,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -141,10 +139,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -177,61 +172,73 @@ agents: type: ProxyAgent observation_space: - type: UC2BlueObservation + type: CUSTOM options: - num_services_per_node: 1 - num_folders_per_node: 1 - num_files_per_folder: 1 - num_nics_per_node: 2 - nodes: - - node_hostname: domain_controller - services: - - service_name: DNSServer - - node_hostname: web_server - services: - - service_name: WebServer - - node_hostname: database_server - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - 
node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py index af5e9650..cb83ac5e 100644 --- a/tests/integration_tests/game_layer/observations/test_file_system_observations.py +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -72,3 +72,6 
@@ def test_folder_observation(simulation): observation_state = root_folder_obs.observe(simulation.describe_state()) assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too + + +# TODO: Add tests to check num access is correct. diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py index 1a41cad4..3eee72e8 100644 --- a/tests/integration_tests/game_layer/observations/test_link_observations.py +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -51,31 +51,8 @@ def simulation() -> Simulation: return sim -def test_link_observation(simulation): - """Test the link observation.""" - # get a link - link: Link = next(iter(simulation.network.links.values())) - - computer: Computer = simulation.network.get_node_by_hostname("computer") - server: Server = simulation.network.get_node_by_hostname("server") - - simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep - - link_obs = LinkObservation(where=["network", "links", link.uuid]) - - assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10 - - observation_state = link_obs.observe(simulation.describe_state()) - assert observation_state.get("PROTOCOLS") is not None - assert observation_state["PROTOCOLS"]["ALL"] == 0 - - computer.ping(server.network_interface.get(1).ip_address) - - observation_state = link_obs.observe(simulation.describe_state()) - assert observation_state["PROTOCOLS"]["ALL"] == 1 - - -def test_link_observation_again(): +def test_link_observation(): + """Check the shape and contents of the link observation.""" net = Network() sim = Simulation(network=net) switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON) @@ -102,6 +79,8 @@ def test_link_observation_again(): assert "PROTOCOLS" in link_2_obs assert 
"ALL" in link_1_obs["PROTOCOLS"] assert "ALL" in link_2_obs["PROTOCOLS"] + assert link_1_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) + assert link_2_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) assert link_1_obs["PROTOCOLS"]["ALL"] == 0 assert link_2_obs["PROTOCOLS"]["ALL"] == 0 diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index 0a34ab67..ed07e030 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -19,3 +19,8 @@ def test_file_observation(): ) assert dog_file_obs.observe(state) == {"health_status": 1} assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)}) + + +# TODO: +# def test_file_num_access(): +# ...