New observations

This commit is contained in:
Marek Wolan
2024-03-27 22:11:02 +00:00
parent fbb4eba6b7
commit cae9f64b93
6 changed files with 727 additions and 309 deletions

View File

@@ -2,7 +2,7 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.host import NodeObservation
from primaite.game.agent.observations.observations import (
AbstractObservation,
AclObservation,
@@ -136,53 +136,3 @@ class UC2BlueObservation(AbstractObservation):
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
class UC2RedObservation(AbstractObservation):
"""Container for all observations used by the red agent in UC2."""
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation):
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
pass

View File

@@ -1,21 +1,473 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.software_observation import ServiceObservation
# from primaite.game.agent.observations.file_system_observations import FolderObservation
# from primaite.game.agent.observations.nic_observations import NicObservation
# from primaite.game.agent.observations.software_observation import ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
WhereType = Iterable[str | int] | None
class NodeObservation(AbstractObservation):
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
class ConfigSchema(AbstractObservation.ConfigSchema):
service_name: str
def __init__(self, where: WhereType)->None:
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0}
def observe(self, state: Dict) -> Any:
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {
"operating_status": service_state["operating_state"],
"health_status": service_state["health_state_visible"],
}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
return cls(where=parent_where+["services", config.service_name])
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
class ConfigSchema(AbstractObservation.ConfigSchema):
application_name: str
def __init__(self, where: WhereType)->None:
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
def observe(self, state: Dict) -> Any:
# raise NotImplementedError("TODO NUM EXECUTIONS NEEDS TO BE CONVERTED TO A CATEGORICAL")
application_state = access_from_nested_dict(state, self.where)
if application_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"num_executions": application_state["num_executions"],
}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({
"operating_status": spaces.Discrete(7),
"health_status": spaces.Discrete(5),
"num_executions": spaces.Discrete(4)
})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ApplicationObservation:
return cls(where=parent_where+["applications", config.application_name])
class FileObservation(AbstractObservation, identifier="FILE"):
class ConfigSchema(AbstractObservation.ConfigSchema):
file_name: str
include_num_access : bool = False
def __init__(self, where: WhereType, include_num_access: bool)->None:
self.where: WhereType = where
self.include_num_access :bool = include_num_access
self.default_observation: ObsType = {"health_status": 0}
if self.include_num_access:
self.default_observation["num_access"] = 0
def observe(self, state: Dict) -> Any:
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {"health_status": file_state["visible_status"]}
if self.include_num_access:
obs["num_access"] = file_state["num_access"]
# raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.")
return obs
@property
def space(self) -> spaces.Space:
space = {"health_status": spaces.Discrete(6)}
if self.include_num_access:
space["num_access"] = spaces.Discrete(4)
return spaces.Dict(space)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation:
return cls(where=parent_where+["files", config.file_name], include_num_access=config.include_num_access)
class FolderObservation(AbstractObservation, identifier="FOLDER"):
class ConfigSchema(AbstractObservation.ConfigSchema):
folder_name: str
files: List[FileObservation.ConfigSchema] = []
num_files : int = 0
include_num_access : bool = False
def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None:
self.where: WhereType = where
self.files: List[FileObservation] = files
while len(self.files) < num_files:
self.files.append(FileObservation(where=None,include_num_access=include_num_access))
while len(self.files) > num_files:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
_LOGGER.warning(msg)
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
def observe(self, state: Dict) -> Any:
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
health_status = folder_state["health_status"]
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation:
where = parent_where + ["folders", config.folder_name]
#pass down shared/common config items
for file_config in config.files:
file_config.include_num_access = config.include_num_access
files = [FileObservation.from_config(config=f, parent_where = where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
class ConfigSchema(AbstractObservation.ConfigSchema):
nic_num: int
include_nmne: bool = False
def __init__(self, where: WhereType, include_nmne: bool)->None:
self.where = where
self.include_nmne : bool = include_nmne
self.default_observation: ObsType = {"nic_status": 0}
if self.include_nmne:
self.default_observation.update({"NMNE":{"inbound":0, "outbound":0}})
def observe(self, state: Dict) -> Any:
# raise NotImplementedError("TODO: CATEGORISATION")
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
if self.include_nmne:
obs.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs
@property
def space(self) -> spaces.Space:
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if self.include_nmne:
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne)
class HostObservation(AbstractObservation, identifier="HOST"):
class ConfigSchema(AbstractObservation.ConfigSchema):
hostname: str
services: List[ServiceObservation.ConfigSchema] = []
applications: List[ApplicationObservation.ConfigSchema] = []
folders: List[FolderObservation.ConfigSchema] = []
network_interfaces: List[NICObservation.ConfigSchema] = []
num_services: int
num_applications: int
num_folders: int
num_files: int
num_nics: int
include_nmne: bool
include_num_access: bool
def __init__(self,
where: WhereType,
services:List[ServiceObservation],
applications:List[ApplicationObservation],
folders:List[FolderObservation],
network_interfaces:List[NICObservation],
num_services: int,
num_applications: int,
num_folders: int,
num_files: int,
num_nics: int,
include_nmne: bool,
include_num_access: bool
)->None:
self.where : WhereType = where
# ensure service list has length equal to num_services by truncating or padding
self.services: List[ServiceObservation] = services
while len(self.services) < num_services:
self.services.append(ServiceObservation(where=None))
while len(self.services) > num_services:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
# ensure application list has length equal to num_applications by truncating or padding
self.applications: List[ApplicationObservation] = applications
while len(self.applications) < num_applications:
self.applications.append(ApplicationObservation(where=None))
while len(self.applications) > num_applications:
truncated_application = self.applications.pop()
msg = f"Too many applications in Node observation space for node. Truncating application {truncated_application.where}"
_LOGGER.warning(msg)
# ensure folder list has length equal to num_folders by truncating or padding
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders:
self.folders.append(FolderObservation(where = None, files= [], num_files=num_files, include_num_access=include_num_access))
while len(self.folders) > num_folders:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
_LOGGER.warning(msg)
# ensure network_interface list has length equal to num_network_interfaces by truncating or padding
self.network_interfaces: List[NICObservation] = network_interfaces
while len(self.network_interfaces) < num_nics:
self.network_interfaces.append(NICObservation(where = None, include_nmne=include_nmne))
while len(self.network_interfaces) > num_nics:
truncated_nic = self.network_interfaces.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_folder.where}"
_LOGGER.warning(msg)
self.default_observation: ObsType = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
"num_file_creations": 0,
"num_file_deletions": 0,
}
def observe(self, state: Dict) -> Any:
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
return obs
@property
def space(self) -> spaces.Space:
shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
"num_file_creations" : spaces.Discrete(4),
"num_file_deletions" : spaces.Discrete(4),
}
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation:
if parent_where is None:
where = ["network", "nodes", config.hostname]
else:
where = parent_where + ["nodes", config.hostname]
#pass down shared/common config items
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
services = [ServiceObservation.from_config(config=c,parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
return cls(
where = where,
services = services,
applications = applications,
folders = folders,
network_interfaces = nics,
num_services = config.num_services,
num_applications = config.num_applications,
num_folders = config.num_folders,
num_files = config.num_files,
num_nics = config.num_nics,
include_nmne = config.include_nmne,
include_num_access = config.include_num_access,
)
class PortObservation(AbstractObservation, identifier="PORT"):
class ConfigSchema(AbstractObservation.ConfigSchema):
pass
def __init__(self, where: WhereType)->None:
pass
def observe(self, state: Dict) -> Any:
pass
@property
def space(self) -> spaces.Space:
pass
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
class ACLObservation(AbstractObservation, identifier="ACL"):
class ConfigSchema(AbstractObservation.ConfigSchema):
pass
def __init__(self, where: WhereType)->None:
pass
def observe(self, state: Dict) -> Any:
pass
@property
def space(self) -> spaces.Space:
pass
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
class RouterObservation(AbstractObservation, identifier="ROUTER"):
class ConfigSchema(AbstractObservation.ConfigSchema):
hostname: str
ports: List[PortObservation.ConfigSchema]
def __init__(self, where: WhereType)->None:
pass
def observe(self, state: Dict) -> Any:
pass
@property
def space(self) -> spaces.Space:
pass
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
class ConfigSchema(AbstractObservation.ConfigSchema):
hostname: str
ports: List[PortObservation.ConfigSchema] = []
def __init__(self, where: WhereType)->None:
pass
def observe(self, state: Dict) -> Any:
pass
@property
def space(self) -> spaces.Space:
pass
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
class NodesObservation(AbstractObservation, identifier="NODES"):
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Config"""
hosts: List[HostObservation.ConfigSchema] = []
routers: List[RouterObservation.ConfigSchema] = []
firewalls: List[FirewallObservation.ConfigSchema] = []
num_services: int = 1
def __init__(self, where: WhereType)->None:
pass
def observe(self, state: Dict) -> Any:
pass
@property
def space(self) -> spaces.Space:
pass
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
############################ OLD
class NodeObservation(AbstractObservation, identifier= "OLD"):
"""Observation of a node in the network. Includes services, folders and NICs."""
def __init__(

View File

@@ -2,11 +2,6 @@ from typing import Dict, TYPE_CHECKING
from gymnasium.core import ObsType
from primaite.game.agent.observations.agent_observations import (
UC2BlueObservation,
UC2GreenObservation,
UC2RedObservation,
)
from primaite.game.agent.observations.observations import AbstractObservation
if TYPE_CHECKING:
@@ -63,11 +58,10 @@ class ObservationManager:
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
else:
raise ValueError("Observation space type invalid")
for obs_cfg in config:
obs_type = obs_cfg['type']
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(obs_class.ConfigSchema(**obs_cfg['options']))
obs_manager = cls(observation)
return obs_manager

View File

@@ -1,9 +1,10 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type
from gymnasium import spaces
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -17,6 +18,28 @@ if TYPE_CHECKING:
class AbstractObservation(ABC):
"""Abstract class for an observation space component."""
class ConfigSchema(ABC, BaseModel):
model_config = ConfigDict(extra="forbid")
_registry: Dict[str, Type["AbstractObservation"]] = {}
"""Registry of observation components, with their name as key.
Automatically populated when subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register an observation type.
:param identifier: Identifier used to uniquely specify observation component types.
:type identifier: str
:raises ValueError: When attempting to create a component with a name that is already in use.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate observation component type {identifier}")
cls._registry[identifier] = cls
@abstractmethod
def observe(self, state: Dict) -> Any:
"""
@@ -36,274 +59,271 @@ class AbstractObservation(ABC):
pass
@classmethod
@abstractmethod
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `game` parameter is for a the PrimaiteGame object that spawns this component.
"""
pass
def from_config(cls, cfg: Dict) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
ObservationType = cls._registry[cfg['type']]
return ObservationType.from_config(cfg=cfg)
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
# class LinkObservation(AbstractObservation):
# """Observation of a link in the network."""
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
# "Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
# def __init__(self, where: Optional[Tuple[str]] = None) -> None:
# """Initialise link observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
# :param where: Store information about where in the simulation state dictionary to find the relevant information.
# Optional. If None, this corresponds that the file does not exist and the observation will be populated with
# zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
# A typical location for a service looks like this:
# `['network','nodes',<node_hostname>,'servics', <service_name>]`
# :type where: Optional[List[str]]
# """
# super().__init__()
# self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
# def observe(self, state: Dict) -> Dict:
# """Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
# :param state: Simulation state dictionary
# :type state: Dict
# :return: Observation
# :rtype: Dict
# """
# if self.where is None:
# return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# link_state = access_from_nested_dict(state, self.where)
# if link_state is NOT_PRESENT_IN_STATE:
# return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# bandwidth = link_state["bandwidth"]
# load = link_state["current_load"]
# if load == 0:
# utilisation_category = 0
# else:
# utilisation_fraction = load / bandwidth
# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
# utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
# # TODO: once the links support separte load per protocol, this needs amendment to reflect that.
# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
# @property
# def space(self) -> spaces.Space:
# """Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
# :return: Gymnasium space
# :rtype: spaces.Space
# """
# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
# @classmethod
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
# """Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
# :param config: Dictionary containing the configuration for this link observation.
# :type config: Dict
# :param game: Reference to the PrimaiteGame object that spawned this observation.
# :type game: PrimaiteGame
# :return: Constructed link observation
# :rtype: LinkObservation
# """
# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# class AclObservation(AbstractObservation):
# """Observation of an Access Control List (ACL) in the network."""
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
# # TODO: should where be optional, and we can use where=None to pad the observation space?
# # definitely the current approach does not support tracking files that aren't specified by name, for example
# # if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# # this needs adding, but not for the MVP.
# def __init__(
# self,
# node_ip_to_id: Dict[str, int],
# ports: List[int],
# protocols: List[str],
# where: Optional[Tuple[str]] = None,
# num_rules: int = 10,
# ) -> None:
# """Initialise ACL observation.
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
# :param node_ip_to_id: Mapping between IP address and ID.
# :type node_ip_to_id: Dict[str, int]
# :param ports: List of ports which are part of the game that define the ordering when converting to an ID
# :type ports: List[int]
# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
# :type protocols: list[str]
# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
# example may look like this:
# ['network','nodes',<router_hostname>,'acl','acl']
# :type where: Optional[Tuple[str]], optional
# :param num_rules: , defaults to 10
# :type num_rules: int, optional
# """
# super().__init__()
# self.where: Optional[Tuple[str]] = where
# self.num_rules: int = num_rules
# self.node_to_id: Dict[str, int] = node_ip_to_id
# "List of node IP addresses, order in this list determines how they are converted to an ID"
# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
# "List of ports which are part of the game that define the ordering when converting to an ID"
# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
# "List of protocols which are part of the game, defines ordering when converting to an ID"
# self.default_observation: Dict = {
# i
# + 1: {
# "position": i,
# "permission": 0,
# "source_node_id": 0,
# "source_port": 0,
# "dest_node_id": 0,
# "dest_port": 0,
# "protocol": 0,
# }
# for i in range(self.num_rules)
# }
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
# def observe(self, state: Dict) -> Dict:
# """Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# :param state: Simulation state dictionary
# :type state: Dict
# :return: Observation
# :rtype: Dict
# """
# if self.where is None:
# return self.default_observation
# acl_state: Dict = access_from_nested_dict(state, self.where)
# if acl_state is NOT_PRESENT_IN_STATE:
# return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
# # TODO: what if the ACL has more rules than num of max rules for obs space
# obs = {}
# acl_items = dict(acl_state.items())
# i = 1 # don't show rule 0 for compatibility reasons.
# while i < self.num_rules + 1:
# rule_state = acl_items[i]
# if rule_state is None:
# obs[i] = {
# "position": i - 1,
# "permission": 0,
# "source_node_id": 0,
# "source_port": 0,
# "dest_node_id": 0,
# "dest_port": 0,
# "protocol": 0,
# }
# else:
# src_ip = rule_state["src_ip_address"]
# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
# dst_ip = rule_state["dst_ip_address"]
# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
# src_port = rule_state["src_port"]
# src_port_id = 1 if src_port is None else self.port_to_id[src_port]
# dst_port = rule_state["dst_port"]
# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
# protocol = rule_state["protocol"]
# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
# obs[i] = {
# "position": i - 1,
# "permission": rule_state["action"],
# "source_node_id": src_node_id,
# "source_port": src_port_id,
# "dest_node_id": dst_node_ip,
# "dest_port": dst_port_id,
# "protocol": protocol_id,
# }
# i += 1
# return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
# @property
# def space(self) -> spaces.Space:
# """Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
# :return: Gymnasium space
# :rtype: spaces.Space
# """
# return spaces.Dict(
# {
# i
# + 1: spaces.Dict(
# {
# "position": spaces.Discrete(self.num_rules),
# "permission": spaces.Discrete(3),
# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
# "source_port": spaces.Discrete(len(self.port_to_id) + 2),
# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
# "dest_port": spaces.Discrete(len(self.port_to_id) + 2),
# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
# }
# )
# for i in range(self.num_rules)
# }
# )
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
# @classmethod
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
# """Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_hostname"]
nic_num = ip_map_config["nic_num"]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.network_interface[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
# :param config: Dictionary containing the configuration for this ACL observation.
# :type config: Dict
# :param game: Reference to the PrimaiteGame object that spawned this observation.
# :type game: PrimaiteGame
# :return: Observation object
# :rtype: AclObservation
# """
# max_acl_rules = config["options"]["max_acl_rules"]
# node_ip_to_idx = {}
# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
# node_ref = ip_map_config["node_hostname"]
# nic_num = ip_map_config["nic_num"]
# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
# nic_obj = node_obj.network_interface[nic_num]
# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
# router_hostname = config["router_hostname"]
# return cls(
# node_ip_to_id=node_ip_to_idx,
# ports=game.options.ports,
# protocols=game.options.protocols,
# where=["network", "nodes", router_hostname, "acl", "acl"],
# num_rules=max_acl_rules,
# )
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
# class NullObservation(AbstractObservation):
# """Null observation, returns a single 0 value for the observation space."""
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
# def __init__(self, where: Optional[List[str]] = None):
# """Initialise null observation."""
# self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
# def observe(self, state: Dict) -> Dict:
# """Generate observation based on the current state of the simulation."""
# return 0
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
# @property
# def space(self) -> spaces.Space:
# """Gymnasium space object describing the observation space shape."""
# return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
# @classmethod
# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
# """
# Create null observation from a config.
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
return cls()
# The parameters are ignored, they are here to match the signature of the other observation classes.
# """
# return cls()
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
# class ICSObservation(NullObservation):
# """ICS observation placeholder, currently not implemented so always returns a single 0."""
pass
# pass

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Hashable, Sequence
from typing import Any, Dict, Hashable, Optional, Sequence
NOT_PRESENT_IN_STATE = object()
"""
@@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
"""
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any:
"""
Access an item from a deeply dictionary with a list of keys.
@@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
:return: The value in the dictionary
:rtype: Any
"""
if keys is None:
return NOT_PRESENT_IN_STATE
key_list = [*keys] # copy keys to a new list to prevent editing original list
if len(key_list) == 0:
return dictionary

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.host import NodeObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.sim_container import Simulation