#2417 Add NestedObservation
This commit is contained in:
@@ -40,7 +40,7 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
protocol_list: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an ACL observation instance.
|
||||
Initialise an ACL observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
|
||||
:type where: WhereType
|
||||
|
||||
@@ -25,7 +25,7 @@ class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool) -> None:
|
||||
"""
|
||||
Initialize a file observation instance.
|
||||
Initialise a file observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this file.
|
||||
A typical location for a file might be
|
||||
@@ -107,7 +107,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a folder observation instance.
|
||||
Initialise a folder observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
|
||||
|
||||
@@ -42,7 +42,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
num_rules: int,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a firewall observation instance.
|
||||
Initialise a firewall observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
|
||||
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
|
||||
|
||||
@@ -62,7 +62,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
include_num_access: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a host observation instance.
|
||||
Initialise a host observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this host.
|
||||
A typical location for a host might be ['network', 'nodes', <hostname>].
|
||||
|
||||
@@ -22,7 +22,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool) -> None:
|
||||
"""
|
||||
Initialize a network interface observation instance.
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
|
||||
A typical location for a network interface might be
|
||||
@@ -108,7 +108,7 @@ class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialize a port observation instance.
|
||||
Initialise a port observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this port.
|
||||
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
|
||||
|
||||
@@ -61,7 +61,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
firewalls: List[FirewallObservation],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a nodes observation instance.
|
||||
Initialise a nodes observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for nodes.
|
||||
A typical location for nodes might be ['network', 'nodes'].
|
||||
|
||||
@@ -1,6 +1,10 @@
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
|
||||
@@ -8,6 +12,114 @@ if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
"""Observation type that allows combining other observations into a gymnasium.spaces.Dict space."""
|
||||
|
||||
class NestedObservationItem(BaseModel):
|
||||
"""One list item of the config."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
"""Select observation class. It maps to the identifier of the obs class by checking the registry."""
|
||||
label: str
|
||||
"""Dict key in the final observation space."""
|
||||
options: Dict
|
||||
"""Options to pass to the observation class from_config method."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_model(self) -> "NestedObservation.NestedObservationItem":
|
||||
"""Make sure tha the config options match up with the selected observation type."""
|
||||
obs_subclass_name = self.type
|
||||
obs_options = self.options
|
||||
if obs_subclass_name not in AbstractObservation._registry:
|
||||
raise ValueError(f"Observation of type {obs_subclass_name} could not be found.")
|
||||
obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema
|
||||
try:
|
||||
obs_schema(**obs_options)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Observation options did not match schema, got this error: {e}")
|
||||
return self
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NestedObservation."""
|
||||
|
||||
components: List[NestedObservation.NestedObservationItem]
|
||||
"""List of observation components to be part of this space."""
|
||||
|
||||
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
|
||||
"""Initialise nested observation."""
|
||||
self.components: Dict[str, AbstractObservation] = components
|
||||
"""Maps label: observation object"""
|
||||
|
||||
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
|
||||
"""Default observation is just the default observations of constituents."""
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status information about the host.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {label: obs.observe(state) for label, obs in self.components.items()}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the nested observation space.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema) -> NestedObservation:
|
||||
"""
|
||||
Read the Nested observation config and create all defined subcomponents.
|
||||
|
||||
Example configuration that utilises NestedObservation:
|
||||
This lets us have different options for different types of hosts.
|
||||
|
||||
```yaml
|
||||
observation_space:
|
||||
- type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
|
||||
- type: HOSTS
|
||||
label: COMPUTERS # What is the dictionary key called
|
||||
options:
|
||||
hosts:
|
||||
- client_1
|
||||
- client_2
|
||||
num_services: 0
|
||||
num_applications: 5
|
||||
... # other options
|
||||
|
||||
- type: HOSTS
|
||||
label: SERVERS # What is the dictionary key called
|
||||
options:
|
||||
hosts:
|
||||
- hostname: database_server
|
||||
- hostname: web_server
|
||||
num_services: 4
|
||||
num_applications: 0
|
||||
num_folders: 2
|
||||
num_files: 2
|
||||
|
||||
```
|
||||
"""
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(obs_class.ConfigSchema(**component.options))
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
|
||||
class ObservationManager:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
@@ -18,18 +130,15 @@ class ObservationManager:
|
||||
3. Formatting this information so an agent can use it to make decisions.
|
||||
"""
|
||||
|
||||
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
|
||||
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
|
||||
# refactor.
|
||||
|
||||
def __init__(self, observation: AbstractObservation) -> None:
|
||||
def __init__(self, obs: AbstractObservation) -> None:
|
||||
"""Initialise observation space.
|
||||
|
||||
:param observation: Observation object
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
self.obs: AbstractObservation = obs
|
||||
self.current_observation: ObsType
|
||||
"""Cached copy of the observation at the time it was most recently calculated."""
|
||||
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
@@ -48,7 +157,8 @@ class ObservationManager:
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
"""
|
||||
Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
It should contain the key 'type' which selects which observation class to use (from a choice of:
|
||||
@@ -58,10 +168,8 @@ class ObservationManager:
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
|
||||
for obs_cfg in config:
|
||||
obs_type = obs_cfg['type']
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options']))
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(obs_class.ConfigSchema(**config["options"]))
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
return obs_manager
|
||||
|
||||
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Iterable, Type
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -26,6 +27,10 @@ class AbstractObservation(ABC):
|
||||
Automatically populated when subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise an observation. This method must be overwritten."""
|
||||
self.default_observation: ObsType
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an observation type.
|
||||
@@ -58,10 +63,10 @@ class AbstractObservation(ABC):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict) -> "AbstractObservation":
|
||||
@abstractmethod
|
||||
def from_config(cls, config: ConfigSchema) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
ObservationType = cls._registry[cfg["type"]]
|
||||
return ObservationType.from_config(cfg=cfg)
|
||||
return cls()
|
||||
|
||||
|
||||
'''
|
||||
|
||||
@@ -47,7 +47,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
acl: ACLObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a router observation instance.
|
||||
Initialise a router observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this router.
|
||||
A typical location for a router might be ['network', 'nodes', <node_hostname>].
|
||||
|
||||
@@ -20,7 +20,7 @@ class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialize a service observation instance.
|
||||
Initialise a service observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this service.
|
||||
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
|
||||
|
||||
Reference in New Issue
Block a user