diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py index 70a83881..2148697b 100644 --- a/src/primaite/game/agent/observations/agent_observations.py +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces -from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.game.agent.observations.host import NodeObservation from primaite.game.agent.observations.observations import ( AbstractObservation, AclObservation, @@ -136,53 +136,3 @@ class UC2BlueObservation(AbstractObservation): new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) return new - -class UC2RedObservation(AbstractObservation): - """Container for all observations used by the red agent in UC2.""" - - def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: - super().__init__() - self.where: Optional[List[str]] = where - self.nodes: List[NodeObservation] = nodes - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": - """ - Create UC2 red observation from a config. - - :param config: Dictionary containing the configuration for this UC2 red observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] - return cls(nodes=nodes, where=["network"]) - - -class UC2GreenObservation(NullObservation): - """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" - - pass diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 94f0974b..42bdb749 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,21 +1,473 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from __future__ import annotations +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union from gymnasium import spaces +from gymnasium.core import ObsType +from pydantic import BaseModel, ConfigDict from primaite import getLogger -from primaite.game.agent.observations.file_system_observations import FolderObservation -from primaite.game.agent.observations.nic_observations import NicObservation from primaite.game.agent.observations.observations import AbstractObservation -from primaite.game.agent.observations.software_observation import ServiceObservation +# from primaite.game.agent.observations.file_system_observations import FolderObservation +# from primaite.game.agent.observations.nic_observations import NicObservation +# from primaite.game.agent.observations.software_observation import ServiceObservation from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +WhereType = Iterable[str | int] | None -class NodeObservation(AbstractObservation): +class ServiceObservation(AbstractObservation, identifier="SERVICE"): + class ConfigSchema(AbstractObservation.ConfigSchema): + service_name: str + + def __init__(self, where: WhereType)->None: + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0} + + def observe(self, state: Dict) -> Any: + service_state = access_from_nested_dict(state, self.where) + if service_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": service_state["operating_state"], + "health_status": service_state["health_state_visible"], + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + return cls(where=parent_where+["services", config.service_name]) + + +class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): + class ConfigSchema(AbstractObservation.ConfigSchema): + application_name: str + + def __init__(self, where: WhereType)->None: + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} + + def observe(self, state: Dict) -> Any: + # raise NotImplementedError("TODO NUM EXECUTIONS NEEDS TO BE CONVERTED TO A CATEGORICAL") + application_state = access_from_nested_dict(state, self.where) + if application_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": application_state["operating_state"], + "health_status": application_state["health_state_visible"], + "num_executions": application_state["num_executions"], + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict({ + "operating_status": spaces.Discrete(7), + "health_status": spaces.Discrete(5), + "num_executions": spaces.Discrete(4) + }) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ApplicationObservation: + return cls(where=parent_where+["applications", config.application_name]) + + +class FileObservation(AbstractObservation, identifier="FILE"): + class ConfigSchema(AbstractObservation.ConfigSchema): + file_name: str + include_num_access : bool = False + + def __init__(self, where: WhereType, include_num_access: bool)->None: + self.where: WhereType = where + self.include_num_access :bool = include_num_access + + self.default_observation: ObsType = {"health_status": 0} + if self.include_num_access: + self.default_observation["num_access"] = 0 + + def observe(self, state: Dict) -> Any: + file_state = access_from_nested_dict(state, self.where) + if file_state is NOT_PRESENT_IN_STATE: + return self.default_observation + obs = {"health_status": file_state["visible_status"]} + if self.include_num_access: + obs["num_access"] = file_state["num_access"] + # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") + return obs + + @property + def space(self) -> spaces.Space: + space = {"health_status": spaces.Discrete(6)} + if self.include_num_access: + space["num_access"] = spaces.Discrete(4) + return spaces.Dict(space) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation: + return cls(where=parent_where+["files", config.file_name], include_num_access=config.include_num_access) + + +class FolderObservation(AbstractObservation, identifier="FOLDER"): + class ConfigSchema(AbstractObservation.ConfigSchema): + folder_name: str + files: List[FileObservation.ConfigSchema] = [] + num_files : int = 0 + include_num_access : bool = False + + def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None: + self.where: WhereType = where + + self.files: List[FileObservation] = files + while len(self.files) < num_files: + self.files.append(FileObservation(where=None,include_num_access=include_num_access)) + while len(self.files) > num_files: + truncated_file = self.files.pop() + msg = f"Too many files in folder observation. Truncating file {truncated_file}" + _LOGGER.warning(msg) + + self.default_observation = { + "health_status": 0, + "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, + } + + def observe(self, state: Dict) -> Any: + folder_state = access_from_nested_dict(state, self.where) + if folder_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + health_status = folder_state["health_status"] + + obs = {} + + obs["health_status"] = health_status + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "health_status": spaces.Discrete(6), + "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), + } + ) + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation: + where = parent_where + ["folders", config.folder_name] + + #pass down shared/common config items + for file_config in config.files: + file_config.include_num_access = config.include_num_access + + files = [FileObservation.from_config(config=f, parent_where = where) for f in config.files] + return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) + + +class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): + class ConfigSchema(AbstractObservation.ConfigSchema): + nic_num: int + include_nmne: bool = False + + + def __init__(self, where: WhereType, include_nmne: bool)->None: + self.where = where + self.include_nmne : bool = include_nmne + + self.default_observation: ObsType = {"nic_status": 0} + if self.include_nmne: + self.default_observation.update({"NMNE":{"inbound":0, "outbound":0}}) + + def observe(self, state: Dict) -> Any: + # raise NotImplementedError("TODO: CATEGORISATION") + nic_state = access_from_nested_dict(state, self.where) + + if nic_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {"nic_status": 1 if nic_state["enabled"] else 2} + if self.include_nmne: + obs.update({"NMNE": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + self.nmne_inbound_last_step = inbound_count + self.nmne_outbound_last_step = outbound_count + return obs + + + @property + def space(self) -> spaces.Space: + space = spaces.Dict({"nic_status": spaces.Discrete(3)}) + + if self.include_nmne: + space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + + return space + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne) + + +class HostObservation(AbstractObservation, identifier="HOST"): + class ConfigSchema(AbstractObservation.ConfigSchema): + hostname: str + services: List[ServiceObservation.ConfigSchema] = [] + applications: List[ApplicationObservation.ConfigSchema] = [] + folders: List[FolderObservation.ConfigSchema] = [] + network_interfaces: List[NICObservation.ConfigSchema] = [] + num_services: int + num_applications: int + num_folders: int + num_files: int + num_nics: int + include_nmne: bool + include_num_access: bool + + def __init__(self, + where: WhereType, + services:List[ServiceObservation], + applications:List[ApplicationObservation], + folders:List[FolderObservation], + network_interfaces:List[NICObservation], + num_services: int, + num_applications: int, + num_folders: int, + num_files: int, + num_nics: int, + include_nmne: bool, + include_num_access: bool + )->None: + + self.where : WhereType = where + + # ensure service list has length equal to num_services by truncating or padding + self.services: List[ServiceObservation] = services + while len(self.services) < num_services: + self.services.append(ServiceObservation(where=None)) + while len(self.services) > num_services: + truncated_service = self.services.pop() + msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" + _LOGGER.warning(msg) + + # ensure application list has length equal to num_applications by truncating or padding + self.applications: List[ApplicationObservation] = applications + while len(self.applications) < num_applications: + self.applications.append(ApplicationObservation(where=None)) + while len(self.applications) > num_applications: + truncated_application = self.applications.pop() + msg = f"Too many applications in Node observation space for node. Truncating application {truncated_application.where}" + _LOGGER.warning(msg) + + # ensure folder list has length equal to num_folders by truncating or padding + self.folders: List[FolderObservation] = folders + while len(self.folders) < num_folders: + self.folders.append(FolderObservation(where = None, files= [], num_files=num_files, include_num_access=include_num_access)) + while len(self.folders) > num_folders: + truncated_folder = self.folders.pop() + msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}" + _LOGGER.warning(msg) + + # ensure network_interface list has length equal to num_network_interfaces by truncating or padding + self.network_interfaces: List[NICObservation] = network_interfaces + while len(self.network_interfaces) < num_nics: + self.network_interfaces.append(NICObservation(where = None, include_nmne=include_nmne)) + while len(self.network_interfaces) > num_nics: + truncated_nic = self.network_interfaces.pop() + msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_folder.where}" + _LOGGER.warning(msg) + + self.default_observation: ObsType = { + "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, + "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "operating_status": 0, + "num_file_creations": 0, + "num_file_deletions": 0, + } + + + def observe(self, state: Dict) -> Any: + node_state = access_from_nested_dict(state, self.where) + if node_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + obs["operating_status"] = node_state["operating_state"] + obs["NICS"] = { + i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) + } + obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] + obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + return obs + + @property + def space(self) -> spaces.Space: + shape = { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(5), + "NICS": spaces.Dict( + {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} + ), + "num_file_creations" : spaces.Discrete(4), + "num_file_deletions" : spaces.Discrete(4), + } + return spaces.Dict(shape) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation: + if parent_where is None: + where = ["network", "nodes", config.hostname] + else: + where = parent_where + ["nodes", config.hostname] + + #pass down shared/common config items + for folder_config in config.folders: + folder_config.include_num_access = config.include_num_access + folder_config.num_files = config.num_files + for nic_config in config.network_interfaces: + nic_config.include_nmne = config.include_nmne + + services = [ServiceObservation.from_config(config=c,parent_where=where) for c in config.services] + applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] + folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] + nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] + + return cls( + where = where, + services = services, + applications = applications, + folders = folders, + network_interfaces = nics, + num_services = config.num_services, + num_applications = config.num_applications, + num_folders = config.num_folders, + num_files = config.num_files, + num_nics = config.num_nics, + include_nmne = config.include_nmne, + include_num_access = config.include_num_access, + ) + + +class PortObservation(AbstractObservation, identifier="PORT"): + class ConfigSchema(AbstractObservation.ConfigSchema): + pass + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class ACLObservation(AbstractObservation, identifier="ACL"): + class ConfigSchema(AbstractObservation.ConfigSchema): + pass + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class RouterObservation(AbstractObservation, identifier="ROUTER"): + class ConfigSchema(AbstractObservation.ConfigSchema): + hostname: str + ports: List[PortObservation.ConfigSchema] + + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class FirewallObservation(AbstractObservation, identifier="FIREWALL"): + class ConfigSchema(AbstractObservation.ConfigSchema): + hostname: str + ports: List[PortObservation.ConfigSchema] = [] + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +class NodesObservation(AbstractObservation, identifier="NODES"): + class ConfigSchema(AbstractObservation.ConfigSchema): + """Config""" + hosts: List[HostObservation.ConfigSchema] = [] + routers: List[RouterObservation.ConfigSchema] = [] + firewalls: List[FirewallObservation.ConfigSchema] = [] + num_services: int = 1 + + + def __init__(self, where: WhereType)->None: + pass + + def observe(self, state: Dict) -> Any: + pass + + @property + def space(self) -> spaces.Space: + pass + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + pass + +############################ OLD + +class NodeObservation(AbstractObservation, identifier= "OLD"): """Observation of a node in the network. Includes services, folders and NICs.""" def __init__( diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 400345fa..be90041e 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -2,11 +2,6 @@ from typing import Dict, TYPE_CHECKING from gymnasium.core import ObsType -from primaite.game.agent.observations.agent_observations import ( - UC2BlueObservation, - UC2GreenObservation, - UC2RedObservation, -) from primaite.game.agent.observations.observations import AbstractObservation if TYPE_CHECKING: @@ -63,11 +58,10 @@ class ObservationManager: :param game: Reference to the PrimaiteGame object that spawned this observation. :type game: PrimaiteGame """ - if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) - else: - raise ValueError("Observation space type invalid") + + for obs_cfg in config: + obs_type = obs_cfg['type'] + obs_class = AbstractObservation._registry[obs_type] + observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options'])) + obs_manager = cls(observation) + return obs_manager \ No newline at end of file diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 6236b00d..dc41e8e5 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,9 +1,10 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type from gymnasium import spaces +from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE @@ -17,6 +18,28 @@ if TYPE_CHECKING: class AbstractObservation(ABC): """Abstract class for an observation space component.""" + class ConfigSchema(ABC, BaseModel): + model_config = ConfigDict(extra="forbid") + + _registry: Dict[str, Type["AbstractObservation"]] = {} + """Registry of observation components, with their name as key. + + Automatically populated when subclasses are defined. Used for defining from_config. + """ + + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: + """ + Register an observation type. + + :param identifier: Identifier used to uniquely specify observation component types. + :type identifier: str + :raises ValueError: When attempting to create a component with a name that is already in use. + """ + super().__init_subclass__(**kwargs) + if identifier in cls._registry: + raise ValueError(f"Duplicate observation component type {identifier}") + cls._registry[identifier] = cls + @abstractmethod def observe(self, state: Dict) -> Any: """ @@ -36,274 +59,271 @@ class AbstractObservation(ABC): pass @classmethod - @abstractmethod - def from_config(cls, config: Dict, game: "PrimaiteGame"): - """Create this observation space component form a serialised format. - - The `game` parameter is for a the PrimaiteGame object that spawns this component. - """ - pass + def from_config(cls, cfg: Dict) -> "AbstractObservation": + """Create this observation space component form a serialised format.""" + ObservationType = cls._registry[cfg['type']] + return ObservationType.from_config(cfg=cfg) -class LinkObservation(AbstractObservation): - """Observation of a link in the network.""" +# class LinkObservation(AbstractObservation): +# """Observation of a link in the network.""" - default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} - "Default observation is what should be returned when the link doesn't exist." +# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} +# "Default observation is what should be returned when the link doesn't exist." - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise link observation. +# def __init__(self, where: Optional[Tuple[str]] = None) -> None: +# """Initialise link observation. - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. +# :param where: Store information about where in the simulation state dictionary to find the relevant information. +# Optional. If None, this corresponds that the file does not exist and the observation will be populated with +# zeroes. - A typical location for a service looks like this: - `['network','nodes',,'servics', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where +# A typical location for a service looks like this: +# `['network','nodes',,'servics', ]` +# :type where: Optional[List[str]] +# """ +# super().__init__() +# self.where: Optional[Tuple[str]] = where - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. +# def observe(self, state: Dict) -> Dict: +# """Generate observation based on the current state of the simulation. - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation +# :param state: Simulation state dictionary +# :type state: Dict +# :return: Observation +# :rtype: Dict +# """ +# if self.where is None: +# return self.default_observation - link_state = access_from_nested_dict(state, self.where) - if link_state is NOT_PRESENT_IN_STATE: - return self.default_observation +# link_state = access_from_nested_dict(state, self.where) +# if link_state is NOT_PRESENT_IN_STATE: +# return self.default_observation - bandwidth = link_state["bandwidth"] - load = link_state["current_load"] - if load == 0: - utilisation_category = 0 - else: - utilisation_fraction = load / bandwidth - # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% - utilisation_category = int(utilisation_fraction * 9) + 1 +# bandwidth = link_state["bandwidth"] +# load = link_state["current_load"] +# if load == 0: +# utilisation_category = 0 +# else: +# utilisation_fraction = load / bandwidth +# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% +# utilisation_category = int(utilisation_fraction * 9) + 1 - # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} +# # TODO: once the links support separte load per protocol, this needs amendment to reflect that. +# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. +# @property +# def space(self) -> spaces.Space: +# """Gymnasium space object describing the observation space shape. - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) +# :return: Gymnasium space +# :rtype: spaces.Space +# """ +# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": - """Create link observation from a config. +# @classmethod +# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": +# """Create link observation from a config. - :param config: Dictionary containing the configuration for this link observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed link observation - :rtype: LinkObservation - """ - return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) +# :param config: Dictionary containing the configuration for this link observation. +# :type config: Dict +# :param game: Reference to the PrimaiteGame object that spawned this observation. +# :type game: PrimaiteGame +# :return: Constructed link observation +# :rtype: LinkObservation +# """ +# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -class AclObservation(AbstractObservation): - """Observation of an Access Control List (ACL) in the network.""" +# class AclObservation(AbstractObservation): +# """Observation of an Access Control List (ACL) in the network.""" - # TODO: should where be optional, and we can use where=None to pad the observation space? - # definitely the current approach does not support tracking files that aren't specified by name, for example - # if a file is created at runtime, we have currently got no way of telling the observation space to track it. - # this needs adding, but not for the MVP. - def __init__( - self, - node_ip_to_id: Dict[str, int], - ports: List[int], - protocols: List[str], - where: Optional[Tuple[str]] = None, - num_rules: int = 10, - ) -> None: - """Initialise ACL observation. +# # TODO: should where be optional, and we can use where=None to pad the observation space? +# # definitely the current approach does not support tracking files that aren't specified by name, for example +# # if a file is created at runtime, we have currently got no way of telling the observation space to track it. +# # this needs adding, but not for the MVP. +# def __init__( +# self, +# node_ip_to_id: Dict[str, int], +# ports: List[int], +# protocols: List[str], +# where: Optional[Tuple[str]] = None, +# num_rules: int = 10, +# ) -> None: +# """Initialise ACL observation. - :param node_ip_to_id: Mapping between IP address and ID. - :type node_ip_to_id: Dict[str, int] - :param ports: List of ports which are part of the game that define the ordering when converting to an ID - :type ports: List[int] - :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID - :type protocols: list[str] - :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical - example may look like this: - ['network','nodes',,'acl','acl'] - :type where: Optional[Tuple[str]], optional - :param num_rules: , defaults to 10 - :type num_rules: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.num_rules: int = num_rules - self.node_to_id: Dict[str, int] = node_ip_to_id - "List of node IP addresses, order in this list determines how they are converted to an ID" - self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} - "List of ports which are part of the game that define the ordering when converting to an ID" - self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} - "List of protocols which are part of the game, defines ordering when converting to an ID" - self.default_observation: Dict = { - i - + 1: { - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - for i in range(self.num_rules) - } +# :param node_ip_to_id: Mapping between IP address and ID. +# :type node_ip_to_id: Dict[str, int] +# :param ports: List of ports which are part of the game that define the ordering when converting to an ID +# :type ports: List[int] +# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID +# :type protocols: list[str] +# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical +# example may look like this: +# ['network','nodes',,'acl','acl'] +# :type where: Optional[Tuple[str]], optional +# :param num_rules: , defaults to 10 +# :type num_rules: int, optional +# """ +# super().__init__() +# self.where: Optional[Tuple[str]] = where +# self.num_rules: int = num_rules +# self.node_to_id: Dict[str, int] = node_ip_to_id +# "List of node IP addresses, order in this list determines how they are converted to an ID" +# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} +# "List of ports which are part of the game that define the ordering when converting to an ID" +# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} +# "List of protocols which are part of the game, defines ordering when converting to an ID" +# self.default_observation: Dict = { +# i +# + 1: { +# "position": i, +# "permission": 0, +# "source_node_id": 0, +# "source_port": 0, +# "dest_node_id": 0, +# "dest_port": 0, +# "protocol": 0, +# } +# for i in range(self.num_rules) +# } - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. +# def observe(self, state: Dict) -> Dict: +# """Generate observation based on the current state of the simulation. - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - acl_state: Dict = access_from_nested_dict(state, self.where) - if acl_state is NOT_PRESENT_IN_STATE: - return self.default_observation +# :param state: Simulation state dictionary +# :type state: Dict +# :return: Observation +# :rtype: Dict +# """ +# if self.where is None: +# return self.default_observation +# acl_state: Dict = access_from_nested_dict(state, self.where) +# if acl_state is NOT_PRESENT_IN_STATE: +# return self.default_observation - # TODO: what if the ACL has more rules than num of max rules for obs space - obs = {} - acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. - while i < self.num_rules + 1: - rule_state = acl_items[i] - if rule_state is None: - obs[i] = { - "position": i - 1, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - else: - src_ip = rule_state["src_ip_address"] - src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] - dst_ip = rule_state["dst_ip_address"] - dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] - src_port = rule_state["src_port"] - src_port_id = 1 if src_port is None else self.port_to_id[src_port] - dst_port = rule_state["dst_port"] - dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] - protocol = rule_state["protocol"] - protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] - obs[i] = { - "position": i - 1, - "permission": rule_state["action"], - "source_node_id": src_node_id, - "source_port": src_port_id, - "dest_node_id": dst_node_ip, - "dest_port": dst_port_id, - "protocol": protocol_id, - } - i += 1 - return obs +# # TODO: what if the ACL has more rules than num of max rules for obs space +# obs = {} +# acl_items = dict(acl_state.items()) +# i = 1 # don't show rule 0 for compatibility reasons. +# while i < self.num_rules + 1: +# rule_state = acl_items[i] +# if rule_state is None: +# obs[i] = { +# "position": i - 1, +# "permission": 0, +# "source_node_id": 0, +# "source_port": 0, +# "dest_node_id": 0, +# "dest_port": 0, +# "protocol": 0, +# } +# else: +# src_ip = rule_state["src_ip_address"] +# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] +# dst_ip = rule_state["dst_ip_address"] +# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] +# src_port = rule_state["src_port"] +# src_port_id = 1 if src_port is None else self.port_to_id[src_port] +# dst_port = rule_state["dst_port"] +# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] +# protocol = rule_state["protocol"] +# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] +# obs[i] = { +# "position": i - 1, +# "permission": rule_state["action"], +# "source_node_id": src_node_id, +# "source_port": src_port_id, +# "dest_node_id": dst_node_ip, +# "dest_port": dst_port_id, +# "protocol": protocol_id, +# } +# i += 1 +# return obs - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. +# @property +# def space(self) -> spaces.Space: +# """Gymnasium space object describing the observation space shape. - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) - } - ) +# :return: Gymnasium space +# :rtype: spaces.Space +# """ +# return spaces.Dict( +# { +# i +# + 1: spaces.Dict( +# { +# "position": spaces.Discrete(self.num_rules), +# "permission": spaces.Discrete(3), +# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) +# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), +# "source_port": spaces.Discrete(len(self.port_to_id) + 2), +# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), +# "dest_port": spaces.Discrete(len(self.port_to_id) + 2), +# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), +# } +# ) +# for i in range(self.num_rules) +# } +# ) - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": - """Generate ACL observation from a config. +# @classmethod +# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": +# """Generate ACL observation from a config. - :param config: Dictionary containing the configuration for this ACL observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Observation object - :rtype: AclObservation - """ - max_acl_rules = config["options"]["max_acl_rules"] - node_ip_to_idx = {} - for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): - node_ref = ip_map_config["node_hostname"] - nic_num = ip_map_config["nic_num"] - node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] - nic_obj = node_obj.network_interface[nic_num] - node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 +# :param config: Dictionary containing the configuration for this ACL observation. +# :type config: Dict +# :param game: Reference to the PrimaiteGame object that spawned this observation. +# :type game: PrimaiteGame +# :return: Observation object +# :rtype: AclObservation +# """ +# max_acl_rules = config["options"]["max_acl_rules"] +# node_ip_to_idx = {} +# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): +# node_ref = ip_map_config["node_hostname"] +# nic_num = ip_map_config["nic_num"] +# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] +# nic_obj = node_obj.network_interface[nic_num] +# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - router_hostname = config["router_hostname"] - return cls( - node_ip_to_id=node_ip_to_idx, - ports=game.options.ports, - protocols=game.options.protocols, - where=["network", "nodes", router_hostname, "acl", "acl"], - num_rules=max_acl_rules, - ) +# router_hostname = config["router_hostname"] +# return cls( +# node_ip_to_id=node_ip_to_idx, +# ports=game.options.ports, +# protocols=game.options.protocols, +# where=["network", "nodes", router_hostname, "acl", "acl"], +# num_rules=max_acl_rules, +# ) -class NullObservation(AbstractObservation): - """Null observation, returns a single 0 value for the observation space.""" +# class NullObservation(AbstractObservation): +# """Null observation, returns a single 0 value for the observation space.""" - def __init__(self, where: Optional[List[str]] = None): - """Initialise null observation.""" - self.default_observation: Dict = {} +# def __init__(self, where: Optional[List[str]] = None): +# """Initialise null observation.""" +# self.default_observation: Dict = {} - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - return 0 +# def observe(self, state: Dict) -> Dict: +# """Generate observation based on the current state of the simulation.""" +# return 0 - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Discrete(1) +# @property +# def space(self) -> spaces.Space: +# """Gymnasium space object describing the observation space shape.""" +# return spaces.Discrete(1) - @classmethod - def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": - """ - Create null observation from a config. +# @classmethod +# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": +# """ +# Create null observation from a config. - The parameters are ignored, they are here to match the signature of the other observation classes. - """ - return cls() +# The parameters are ignored, they are here to match the signature of the other observation classes. +# """ +# return cls() -class ICSObservation(NullObservation): - """ICS observation placeholder, currently not implemented so always returns a single 0.""" +# class ICSObservation(NullObservation): +# """ICS observation placeholder, currently not implemented so always returns a single 0.""" - pass +# pass diff --git a/src/primaite/game/agent/utils.py b/src/primaite/game/agent/utils.py index 1314087c..42e8f30b 100644 --- a/src/primaite/game/agent/utils.py +++ b/src/primaite/game/agent/utils.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Hashable, Sequence +from typing import Any, Dict, Hashable, Optional, Sequence NOT_PRESENT_IN_STATE = object() """ @@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is """ -def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: +def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any: """ Access an item from a deeply dictionary with a list of keys. @@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any: :return: The value in the dictionary :rtype: Any """ + if keys is None: + return NOT_PRESENT_IN_STATE key_list = [*keys] # copy keys to a new list to prevent editing original list if len(key_list) == 0: return dictionary diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py index dce05b6a..a5195e1e 100644 --- a/tests/integration_tests/game_layer/observations/test_node_observations.py +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -4,7 +4,7 @@ from uuid import uuid4 import pytest from gymnasium import spaces -from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.game.agent.observations.host import NodeObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation