From 15cb2e6970a184c83c6d56c01ad3ae3f26660b1e Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Sun, 31 Mar 2024 17:31:10 +0100 Subject: [PATCH] #2417 Add NestedObservation --- .../agent/observations/acl_observation.py | 2 +- .../observations/file_system_observations.py | 4 +- .../observations/firewall_observation.py | 2 +- .../agent/observations/host_observations.py | 2 +- .../agent/observations/nic_observations.py | 4 +- .../agent/observations/node_observations.py | 2 +- .../agent/observations/observation_manager.py | 136 ++++++++++++++++-- .../game/agent/observations/observations.py | 11 +- .../agent/observations/router_observation.py | 2 +- .../observations/software_observation.py | 2 +- 10 files changed, 140 insertions(+), 27 deletions(-) diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py index 2d29223d..7601e678 100644 --- a/src/primaite/game/agent/observations/acl_observation.py +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -40,7 +40,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"): protocol_list: List[str], ) -> None: """ - Initialize an ACL observation instance. + Initialise an ACL observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this ACL. :type where: WhereType diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index a30bfc82..3c931bc8 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -25,7 +25,7 @@ class FileObservation(AbstractObservation, identifier="FILE"): def __init__(self, where: WhereType, include_num_access: bool) -> None: """ - Initialize a file observation instance. + Initialise a file observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this file. A typical location for a file might be @@ -107,7 +107,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool ) -> None: """ - Initialize a folder observation instance. + Initialise a folder observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this folder. A typical location for a folder might be ['network', 'nodes', , 'folders', ]. diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py index 6397d473..376e4824 100644 --- a/src/primaite/game/agent/observations/firewall_observation.py +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -42,7 +42,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): num_rules: int, ) -> None: """ - Initialize a firewall observation instance. + Initialise a firewall observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this firewall. A typical location for a firewall might be ['network', 'nodes', ]. diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index 34c9b3ff..9146979a 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -62,7 +62,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): include_num_access: bool, ) -> None: """ - Initialize a host observation instance. + Initialise a host observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this host. A typical location for a host might be ['network', 'nodes', ]. diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py index 3be53112..ff2731ff 100644 --- a/src/primaite/game/agent/observations/nic_observations.py +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -22,7 +22,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): def __init__(self, where: WhereType, include_nmne: bool) -> None: """ - Initialize a network interface observation instance. + Initialise a network interface observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this interface. A typical location for a network interface might be @@ -108,7 +108,7 @@ class PortObservation(AbstractObservation, identifier="PORT"): def __init__(self, where: WhereType) -> None: """ - Initialize a port observation instance. + Initialise a port observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this port. A typical location for a port might be ['network', 'nodes', , 'NICs', ]. diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 0e63f440..3f384ece 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -61,7 +61,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): firewalls: List[FirewallObservation], ) -> None: """ - Initialize a nodes observation instance. + Initialise a nodes observation instance. :param where: Where in the simulation state dictionary to find the relevant information for nodes. A typical location for nodes might be ['network', 'nodes']. diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index be90041e..a6981ddc 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -1,6 +1,10 @@ -from typing import Dict, TYPE_CHECKING +from __future__ import annotations +from typing import Any, Dict, List, TYPE_CHECKING + +from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import BaseModel, ConfigDict, model_validator, ValidationError from primaite.game.agent.observations.observations import AbstractObservation @@ -8,6 +12,114 @@ if TYPE_CHECKING: from primaite.game.game import PrimaiteGame +class NestedObservation(AbstractObservation, identifier="CUSTOM"): + """Observation type that allows combining other observations into a gymnasium.spaces.Dict space.""" + + class NestedObservationItem(BaseModel): + """One list item of the config.""" + + model_config = ConfigDict(extra="forbid") + type: str + """Select observation class. It maps to the identifier of the obs class by checking the registry.""" + label: str + """Dict key in the final observation space.""" + options: Dict + """Options to pass to the observation class from_config method.""" + + @model_validator(mode="after") + def check_model(self) -> "NestedObservation.NestedObservationItem": + """Make sure tha the config options match up with the selected observation type.""" + obs_subclass_name = self.type + obs_options = self.options + if obs_subclass_name not in AbstractObservation._registry: + raise ValueError(f"Observation of type {obs_subclass_name} could not be found.") + obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema + try: + obs_schema(**obs_options) + except ValidationError as e: + raise ValueError(f"Observation options did not match schema, got this error: {e}") + return self + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for NestedObservation.""" + + components: List[NestedObservation.NestedObservationItem] + """List of observation components to be part of this space.""" + + def __init__(self, components: Dict[str, AbstractObservation]) -> None: + """Initialise nested observation.""" + self.components: Dict[str, AbstractObservation] = components + """Maps label: observation object""" + + self.default_observation = {label: obs.default_observation for label, obs in self.components.items()} + """Default observation is just the default observations of constituents.""" + + def observe(self, state: Dict) -> Any: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status information about the host. + :rtype: ObsType + """ + return {label: obs.observe(state) for label, obs in self.components.items()} + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the nested observation space. + :rtype: spaces.Space + """ + return spaces.Dict({label: obs.space for label, obs in self.components.items()}) + + @classmethod + def from_config(cls, config: ConfigSchema) -> NestedObservation: + """ + Read the Nested observation config and create all defined subcomponents. + + Example configuration that utilises NestedObservation: + This lets us have different options for different types of hosts. + + ```yaml + observation_space: + - type: CUSTOM + options: + components: + + - type: HOSTS + label: COMPUTERS # What is the dictionary key called + options: + hosts: + - client_1 + - client_2 + num_services: 0 + num_applications: 5 + ... # other options + + - type: HOSTS + label: SERVERS # What is the dictionary key called + options: + hosts: + - hostname: database_server + - hostname: web_server + num_services: 4 + num_applications: 0 + num_folders: 2 + num_files: 2 + + ``` + """ + instances = dict() + for component in config.components: + obs_class = AbstractObservation._registry[component.type] + obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options)) + instances[component.label] = obs_instance + return cls(components=instances) + + class ObservationManager: """ Manage the observations of an Agent. @@ -18,18 +130,15 @@ class ObservationManager: 3. Formatting this information so an agent can use it to make decisions. """ - # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed - # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next - # refactor. - - def __init__(self, observation: AbstractObservation) -> None: + def __init__(self, obs: AbstractObservation) -> None: """Initialise observation space. :param observation: Observation object :type observation: AbstractObservation """ - self.obs: AbstractObservation = observation + self.obs: AbstractObservation = obs self.current_observation: ObsType + """Cached copy of the observation at the time it was most recently calculated.""" def update(self, state: Dict) -> Dict: """ @@ -48,7 +157,8 @@ class ObservationManager: @classmethod def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": - """Create observation space from a config. + """ + Create observation space from a config. :param config: Dictionary containing the configuration for this observation space. It should contain the key 'type' which selects which observation class to use (from a choice of: @@ -58,10 +168,8 @@ class ObservationManager: :param game: Reference to the PrimaiteGame object that spawned this observation. :type game: PrimaiteGame """ - - for obs_cfg in config: - obs_type = obs_cfg['type'] - obs_class = AbstractObservation._registry[obs_type] - observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options'])) + obs_type = config["type"] + obs_class = AbstractObservation._registry[obs_type] + observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"])) obs_manager = cls(observation) - return obs_manager \ No newline at end of file + return obs_manager diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index 08871072..feddc3ed 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -3,6 +3,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Iterable, Type from gymnasium import spaces +from gymnasium.core import ObsType from pydantic import BaseModel, ConfigDict from primaite import getLogger @@ -26,6 +27,10 @@ class AbstractObservation(ABC): Automatically populated when subclasses are defined. Used for defining from_config. """ + def __init__(self) -> None: + """Initialise an observation. This method must be overwritten.""" + self.default_observation: ObsType + def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None: """ Register an observation type. @@ -58,10 +63,10 @@ class AbstractObservation(ABC): pass @classmethod - def from_config(cls, cfg: Dict) -> "AbstractObservation": + @abstractmethod + def from_config(cls, config: ConfigSchema) -> "AbstractObservation": """Create this observation space component form a serialised format.""" - ObservationType = cls._registry[cfg["type"]] - return ObservationType.from_config(cfg=cfg) + return cls() ''' diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py index b8dee2c2..97d8ab41 100644 --- a/src/primaite/game/agent/observations/router_observation.py +++ b/src/primaite/game/agent/observations/router_observation.py @@ -47,7 +47,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): acl: ACLObservation, ) -> None: """ - Initialize a router observation instance. + Initialise a router observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this router. A typical location for a router might be ['network', 'nodes', ]. diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index eb94651d..0c031345 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -20,7 +20,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"): def __init__(self, where: WhereType) -> None: """ - Initialize a service observation instance. + Initialise a service observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this service. A typical location for a service might be ['network', 'nodes', , 'services', ].