#2417 Add NestedObservation

This commit is contained in:
Marek Wolan
2024-03-31 17:31:10 +01:00
parent 22e1dfea2f
commit 15cb2e6970
10 changed files with 140 additions and 27 deletions

View File

@@ -40,7 +40,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
protocol_list: List[str],
) -> None:
"""
Initialize an ACL observation instance.
Initialise an ACL observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
:type where: WhereType

View File

@@ -25,7 +25,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
def __init__(self, where: WhereType, include_num_access: bool) -> None:
"""
Initialize a file observation instance.
Initialise a file observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this file.
A typical location for a file might be
@@ -107,7 +107,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
) -> None:
"""
Initialize a folder observation instance.
Initialise a folder observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].

View File

@@ -42,7 +42,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
num_rules: int,
) -> None:
"""
Initialize a firewall observation instance.
Initialise a firewall observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].

View File

@@ -62,7 +62,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
include_num_access: bool,
) -> None:
"""
Initialize a host observation instance.
Initialise a host observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this host.
A typical location for a host might be ['network', 'nodes', <hostname>].

View File

@@ -22,7 +22,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
def __init__(self, where: WhereType, include_nmne: bool) -> None:
"""
Initialize a network interface observation instance.
Initialise a network interface observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
A typical location for a network interface might be
@@ -108,7 +108,7 @@ class PortObservation(AbstractObservation, identifier="PORT"):
def __init__(self, where: WhereType) -> None:
"""
Initialize a port observation instance.
Initialise a port observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this port.
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].

View File

@@ -61,7 +61,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
firewalls: List[FirewallObservation],
) -> None:
"""
Initialize a nodes observation instance.
Initialise a nodes observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for nodes.
A typical location for nodes might be ['network', 'nodes'].

View File

@@ -1,6 +1,10 @@
from typing import Dict, TYPE_CHECKING
from __future__ import annotations
from typing import Any, Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
from primaite.game.agent.observations.observations import AbstractObservation
@@ -8,6 +12,114 @@ if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NestedObservation(AbstractObservation, identifier="CUSTOM"):
"""Observation type that allows combining other observations into a gymnasium.spaces.Dict space."""
class NestedObservationItem(BaseModel):
"""One list item of the config."""
model_config = ConfigDict(extra="forbid")
type: str
"""Select observation class. It maps to the identifier of the obs class by checking the registry."""
label: str
"""Dict key in the final observation space."""
options: Dict
"""Options to pass to the observation class from_config method."""
@model_validator(mode="after")
def check_model(self) -> "NestedObservation.NestedObservationItem":
"""Make sure tha the config options match up with the selected observation type."""
obs_subclass_name = self.type
obs_options = self.options
if obs_subclass_name not in AbstractObservation._registry:
raise ValueError(f"Observation of type {obs_subclass_name} could not be found.")
obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema
try:
obs_schema(**obs_options)
except ValidationError as e:
raise ValueError(f"Observation options did not match schema, got this error: {e}")
return self
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NestedObservation."""
components: List[NestedObservation.NestedObservationItem]
"""List of observation components to be part of this space."""
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
"""Initialise nested observation."""
self.components: Dict[str, AbstractObservation] = components
"""Maps label: observation object"""
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
"""Default observation is just the default observations of constituents."""
def observe(self, state: Dict) -> Any:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status information about the host.
:rtype: ObsType
"""
return {label: obs.observe(state) for label, obs in self.components.items()}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the nested observation space.
:rtype: spaces.Space
"""
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
@classmethod
def from_config(cls, config: ConfigSchema) -> NestedObservation:
"""
Read the Nested observation config and create all defined subcomponents.
Example configuration that utilises NestedObservation:
This lets us have different options for different types of hosts.
```yaml
observation_space:
- type: CUSTOM
options:
components:
- type: HOSTS
label: COMPUTERS # What is the dictionary key called
options:
hosts:
- client_1
- client_2
num_services: 0
num_applications: 5
... # other options
- type: HOSTS
label: SERVERS # What is the dictionary key called
options:
hosts:
- hostname: database_server
- hostname: web_server
num_services: 4
num_applications: 0
num_folders: 2
num_files: 2
```
"""
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options))
instances[component.label] = obs_instance
return cls(components=instances)
class ObservationManager:
"""
Manage the observations of an Agent.
@@ -18,18 +130,15 @@ class ObservationManager:
3. Formatting this information so an agent can use it to make decisions.
"""
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
# refactor.
def __init__(self, observation: AbstractObservation) -> None:
def __init__(self, obs: AbstractObservation) -> None:
"""Initialise observation space.
:param observation: Observation object
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.obs: AbstractObservation = obs
self.current_observation: ObsType
"""Cached copy of the observation at the time it was most recently calculated."""
def update(self, state: Dict) -> Dict:
"""
@@ -48,7 +157,8 @@ class ObservationManager:
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
"""
Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
It should contain the key 'type' which selects which observation class to use (from a choice of:
@@ -58,10 +168,8 @@ class ObservationManager:
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
for obs_cfg in config:
obs_type = obs_cfg['type']
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options']))
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"]))
obs_manager = cls(observation)
return obs_manager
return obs_manager

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Type
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
@@ -26,6 +27,10 @@ class AbstractObservation(ABC):
Automatically populated when subclasses are defined. Used for defining from_config.
"""
def __init__(self) -> None:
"""Initialise an observation. This method must be overwritten."""
self.default_observation: ObsType
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register an observation type.
@@ -58,10 +63,10 @@ class AbstractObservation(ABC):
pass
@classmethod
def from_config(cls, cfg: Dict) -> "AbstractObservation":
@abstractmethod
def from_config(cls, config: ConfigSchema) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
ObservationType = cls._registry[cfg["type"]]
return ObservationType.from_config(cfg=cfg)
return cls()
'''

View File

@@ -47,7 +47,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
acl: ACLObservation,
) -> None:
"""
Initialize a router observation instance.
Initialise a router observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this router.
A typical location for a router might be ['network', 'nodes', <node_hostname>].

View File

@@ -20,7 +20,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"):
def __init__(self, where: WhereType) -> None:
"""
Initialize a service observation instance.
Initialise a service observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this service.
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].