Files
PrimAITE/src/primaite/game/agent/observations.py
2023-11-22 11:59:25 +00:00

986 lines
41 KiB
Python

"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
# Module-level logger; the folder/node observations use it to warn when they truncate child observations.
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
    # Only needed for type annotations (referenced as the string "PrimaiteGame"). Presumably guarded to avoid a
    # circular import at runtime — TODO confirm.
    from primaite.game.game import PrimaiteGame
class AbstractObservation(ABC):
    """
    Base class for every component of an agent's observation space.

    A concrete observation must be able to turn a simulation state dictionary into an observation, describe the
    shape of that observation as a Gymnasium space, and be constructed from a config dictionary.
    """

    @abstractmethod
    def observe(self, state: Dict) -> Any:
        """
        Produce this component's observation from the current simulation state.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Any
        """
        ...

    @property
    @abstractmethod
    def space(self) -> spaces.Space:
        """Gymnasium space describing the shape of the observations returned by ``observe``."""
        ...

    @classmethod
    @abstractmethod
    def from_config(cls, config: Dict, session: "PrimaiteGame"):
        """
        Build this observation component from its serialised (config) form.

        :param config: Configuration dictionary for this component.
        :param session: The game object that spawns this component. During deserialisation a subclass may need it
            to translate a 'reference' found in the config into a UUID.
        """
        ...
class FileObservation(AbstractObservation):
    """Observation of the health of a single file on a node in the network."""

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """
        Initialise file observation.

        :param where: Location of this file's information in the simulation state dictionary. A typical location
            looks like ['network','nodes',<node_uuid>,'file_system','folders',<folder_name>,'files',<file_name>].
            Optional. If None, the file is treated as nonexistent and the observation is always the zeroed default.
        :type where: Optional[List[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where
        # Returned whenever the file cannot be observed, e.g. after it has been deleted.
        self.default_observation: Dict = {"health_status": 0}

    def observe(self, state: Dict) -> Dict:
        """
        Report the file's health status, or the default observation when the file cannot be found.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is not None:
            file_state = access_from_nested_dict(state, self.where)
            if file_state is not NOT_PRESENT_IN_STATE:
                return {"health_status": file_state["health_status"]}
        return self.default_observation

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space describing this observation's shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        health_space = spaces.Discrete(6)
        return spaces.Dict({"health_status": health_space})

    @classmethod
    def from_config(
        cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None
    ) -> "FileObservation":
        """
        Create a file observation from its config.

        :param config: Configuration for this file observation; the ``file_name`` key names the file.
        :type config: Dict
        :param session: The game that spawned this observation. Not used here; accepted to keep a uniform signature.
        :type session: PrimaiteGame
        :param parent_where: Location of the parent folder in the simulation state dictionary. Although it defaults
            to None, a list is required in practice because the file's location is appended to it.
        :type parent_where: Optional[List[str]]
        :return: Constructed file observation
        :rtype: FileObservation
        """
        file_where = parent_where + ["files", config["file_name"]]
        return cls(where=file_where)
class ServiceObservation(AbstractObservation):
    """Observation of the operating and health state of a service in the network."""

    # Returned whenever the service cannot be observed (no location configured, or absent from the state).
    default_observation: Dict = {"operating_status": 0, "health_status": 0}

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """
        Initialise service observation.

        :param where: Location of this service's information in the simulation state dictionary. A typical
            location looks like `['network','nodes',<node_uuid>,'services', <service_uuid>]`. Optional. If None,
            the service is treated as nonexistent and the observation is always the zeroed default.
        :type where: Optional[List[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    def observe(self, state: Dict) -> Dict:
        """
        Report the service's operating and health state, or the default observation if it cannot be found.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is not None:
            service_state = access_from_nested_dict(state, self.where)
            if service_state is not NOT_PRESENT_IN_STATE:
                return {
                    "operating_status": service_state["operating_state"],
                    "health_status": service_state["health_state"],
                }
        return self.default_observation

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space describing this observation's shape."""
        operating_space = spaces.Discrete(7)
        health_space = spaces.Discrete(6)
        return spaces.Dict({"operating_status": operating_space, "health_status": health_space})

    @classmethod
    def from_config(
        cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]] = None
    ) -> "ServiceObservation":
        """
        Create a service observation from its config.

        :param config: Configuration for this service observation; ``service_ref`` is resolved through
            ``session.ref_map_services``.
        :type config: Dict
        :param session: The game that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Location of the service's parent node in the simulation state dictionary. Although it
            defaults to None, a list is required in practice because the service location is appended to it.
        :type parent_where: Optional[List[str]], optional
        :return: Constructed service observation
        :rtype: ServiceObservation
        """
        service_uuid = session.ref_map_services[config["service_ref"]].uuid
        return cls(where=parent_where + ["services", service_uuid])
class LinkObservation(AbstractObservation):
    """Observation of the utilisation of a link in the network."""

    # Returned when the link cannot be observed (no location configured, or absent from the state).
    default_observation: Dict = {"PROTOCOLS": {"ALL": 0}}

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """
        Initialise link observation.

        :param where: Where in the simulation state dictionary to find the relevant information for this link.
            Optional. If None, the link is treated as nonexistent and the observation is populated with zeroes.
            ``from_config`` builds a location of the form
            `['network', 'links', <value of session.ref_map_links[link_ref]>]`.
        :type where: Optional[List[str]]
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    @staticmethod
    def _utilisation_category(load: float, bandwidth: float) -> int:
        """
        Bin a link's load into a utilisation category that always lies inside ``space``.

        0 (UNUSED) is only reported through ``default_observation`` for a link that is not observed. Otherwise
        1 is 0%-10%, 2 is 10%-20%, ... and 10 is 90%-100%, including exactly 100%. A load above the bandwidth is
        clamped to 10.

        :param load: Current load on the link.
        :param bandwidth: Link bandwidth, in the same units as ``load``.
        :return: Utilisation category in the range 1..10.
        """
        if bandwidth <= 0:
            # A link with no capacity has no meaningful utilisation fraction. Treat any load on it as saturation
            # instead of raising ZeroDivisionError.
            utilisation_fraction = 1.0 if load > 0 else 0.0
        else:
            utilisation_fraction = load / bandwidth
        category = int(utilisation_fraction * 10) + 1
        # Without the clamp, exactly 100% (or overload) would give 11+, which is outside Discrete(11).
        return max(1, min(category, 10))

    def observe(self, state: Dict) -> Dict:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        link_state = access_from_nested_dict(state, self.where)
        if link_state is NOT_PRESENT_IN_STATE:
            return self.default_observation
        utilisation_category = self._utilisation_category(link_state["current_load"], link_state["bandwidth"])
        # TODO: once the links support separate load per protocol, this needs amendment to reflect that.
        return {"PROTOCOLS": {"ALL": utilisation_category}}

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteGame") -> "LinkObservation":
        """
        Create link observation from a config.

        :param config: Configuration for this link observation; ``link_ref`` is resolved via
            ``session.ref_map_links``.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Constructed link observation
        :rtype: LinkObservation
        """
        return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
class FolderObservation(AbstractObservation):
    """Folder observation, including files inside of the folder."""

    def __init__(
        self,
        where: Optional[Tuple[str]] = None,
        files: Optional[List[FileObservation]] = None,
        num_files_per_folder: int = 2,
    ) -> None:
        """
        Initialise folder observation, including the files inside of the folder.

        :param where: Where in the simulation state dictionary to find the relevant information for this folder.
            A typical location for a folder looks like this:
            ['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
            If None, the observation is always the zeroed default.
        :type where: Optional[List[str]]
        :param files: File observations for this folder, in observation-space order. The list is copied, so the
            caller's list is never modified. Defaults to no files.
        :type files: Optional[List[FileObservation]]
        :param num_files_per_folder: Number of file slots in this folder's observation. The space must have a
            static size, so ``files`` is padded with blank file observations or truncated (with a warning) to exactly
            this many entries. Defaults to 2.
        :type num_files_per_folder: int, optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where
        # Copy so that padding/truncation below never mutates the caller's list or a shared default list.
        self.files: List[FileObservation] = list(files) if files is not None else []
        while len(self.files) < num_files_per_folder:
            # A blank file observation has no ``where``, so it always reports its default observation.
            self.files.append(FileObservation())
        while len(self.files) > num_files_per_folder:
            truncated_file = self.files.pop()
            _LOGGER.warning("Too many files in folder observation. Truncating file %s", truncated_file.where)
        self.default_observation = {
            "health_status": 0,
            "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
        }

    def observe(self, state: Dict) -> Dict:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        folder_state = access_from_nested_dict(state, self.where)
        if folder_state is NOT_PRESENT_IN_STATE:
            return self.default_observation
        return {
            "health_status": folder_state["health_status"],
            "FILES": {i + 1: file.observe(state) for i, file in enumerate(self.files)},
        }

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        return spaces.Dict(
            {
                "health_status": spaces.Discrete(6),
                "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
            }
        )

    @classmethod
    def from_config(
        cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
    ) -> "FolderObservation":
        """
        Create folder observation from a config. Also creates child file observations.

        :param config: Dictionary containing the configuration for this folder observation. Includes the name of the
            folder (``folder_name``) and, optionally, the files inside of it (``files``, defaults to none).
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Where in the simulation state dictionary to find the information about this folder's
            parent node. A typical location for a node ``where`` can be:
            ['network','nodes',<node_uuid>,'file_system']
        :type parent_where: Optional[List[str]]
        :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
            observation size), defaults to 2
        :type num_files_per_folder: int, optional
        :return: Constructed folder observation
        :rtype: FolderObservation
        """
        where = parent_where + ["folders", config["folder_name"]]
        file_configs = config.get("files", [])
        files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
        return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
class NicObservation(AbstractObservation):
    """Observation of whether a Network Interface Card (NIC) is enabled."""

    # Returned when the NIC cannot be observed (no location configured, or absent from the state).
    default_observation: Dict = {"nic_status": 0}

    def __init__(self, where: Optional[Tuple[str]] = None) -> None:
        """
        Initialise NIC observation.

        :param where: Location of this NIC's information in the simulation state dictionary, typically
            ['network','nodes',<node_uuid>,'NICs',<nic_uuid>]. If None, the NIC is treated as nonexistent and the
            observation is always the zeroed default.
        :type where: Optional[Tuple[str]], optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where

    def observe(self, state: Dict) -> Dict:
        """
        Report whether the NIC is enabled, or the default observation when it cannot be found.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is not None:
            nic_state = access_from_nested_dict(state, self.where)
            if nic_state is not NOT_PRESENT_IN_STATE:
                # 1 = enabled, 2 = disabled; 0 is left for the default (absent NIC) observation.
                status = 1 if nic_state["enabled"] else 2
                return {"nic_status": status}
        return self.default_observation

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space describing this observation's shape."""
        status_space = spaces.Discrete(3)
        return spaces.Dict({"nic_status": status_space})

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
        """
        Create a NIC observation from its config.

        :param config: Configuration for this NIC observation; ``nic_uuid`` identifies the NIC.
        :type config: Dict
        :param session: The game that spawned this observation. Not used here; accepted to keep a uniform signature.
        :type session: PrimaiteSession
        :param parent_where: Location of the NIC's parent node in the simulation state dictionary, typically
            ['network','nodes',<node_uuid>].
        :type parent_where: Optional[List[str]]
        :return: Constructed NIC observation
        :rtype: NicObservation
        """
        nic_where = parent_where + ["NICs", config["nic_uuid"]]
        return cls(where=nic_where)
class NodeObservation(AbstractObservation):
    """Observation of a node in the network. Includes services, folders and NICs."""

    def __init__(
        self,
        where: Optional[Tuple[str]] = None,
        services: Optional[List[ServiceObservation]] = None,
        folders: Optional[List[FolderObservation]] = None,
        nics: Optional[List[NicObservation]] = None,
        logon_status: bool = False,
        num_services_per_node: int = 2,
        num_folders_per_node: int = 2,
        num_files_per_folder: int = 2,
        num_nics_per_node: int = 2,
    ) -> None:
        """
        Configurable observation for a node in the simulation.

        The child lists are copied and then padded with blank observations or truncated (with a warning) to the
        configured sizes. This keeps the observation space's shape static regardless of how many components the
        node actually has.

        :param where: Where in the simulation state dictionary to find relevant information for this observation.
            A typical location for a node looks like this: ['network','nodes',<node_uuid>]. If None, the default
            (zeroed) observation is always returned. Defaults to None.
        :type where: Optional[List[str]]
        :param services: Service observations for this node, in observation-space order. Defaults to none.
        :type services: Optional[List[ServiceObservation]]
        :param folders: Folder observations for this node, in observation-space order. Defaults to none.
        :type folders: Optional[List[FolderObservation]]
        :param nics: NIC observations for this node, in observation-space order. Defaults to none.
        :type nics: Optional[List[NicObservation]]
        :param logon_status: Whether to include a ``logon_status`` entry in the observation. Defaults to False.
        :type logon_status: bool
        :param num_services_per_node: Number of service slots in this node's observation space, defaults to 2
        :type num_services_per_node: int, optional
        :param num_folders_per_node: Number of folder slots in this node's observation space, defaults to 2
        :type num_folders_per_node: int, optional
        :param num_files_per_folder: Number of file slots in each blank (padding) folder created here, defaults to 2
        :type num_files_per_folder: int, optional
        :param num_nics_per_node: Number of NIC slots in this node's observation space, defaults to 2
        :type num_nics_per_node: int, optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where
        self.services: List[ServiceObservation] = self._fit_to_size(
            services, num_services_per_node, ServiceObservation, "services"
        )
        self.folders: List[FolderObservation] = self._fit_to_size(
            folders,
            num_folders_per_node,
            # Pass a fresh ``files`` list so padding folders never share a file list with each other.
            lambda: FolderObservation(files=[], num_files_per_folder=num_files_per_folder),
            "folders",
        )
        self.nics: List[NicObservation] = self._fit_to_size(nics, num_nics_per_node, NicObservation, "NICs")
        self.logon_status: bool = logon_status
        self.default_observation: Dict = {
            "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
            "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
            "NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
            "operating_status": 0,
        }
        if self.logon_status:
            self.default_observation["logon_status"] = 0

    @staticmethod
    def _fit_to_size(items: Optional[List[Any]], size: int, make_blank: Any, label: str) -> List[Any]:
        """
        Return a copy of ``items`` padded with ``make_blank()`` or truncated to exactly ``size`` entries.

        :param items: Child observations to fit; None is treated as an empty list. Never mutated.
        :param size: Required number of entries.
        :param make_blank: Zero-argument callable that creates a blank child observation (one with no ``where``,
            which therefore always reports its default observation).
        :param label: Plural name of the children, used in the truncation warning.
        :return: New list with exactly ``size`` entries.
        """
        fitted = list(items) if items is not None else []
        while len(fitted) < size:
            fitted.append(make_blank())
        while len(fitted) > size:
            truncated = fitted.pop()
            # Log ``where`` itself rather than ``where[-1]`` so a blank (where=None) child cannot crash this.
            _LOGGER.warning("Too many %s in Node observation for node. Truncating %s", label, truncated.where)
        return fitted

    def observe(self, state: Dict) -> Dict:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        node_state = access_from_nested_dict(state, self.where)
        if node_state is NOT_PRESENT_IN_STATE:
            return self.default_observation
        obs = {}
        obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
        obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
        obs["operating_status"] = node_state["operating_state"]
        obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
        if self.logon_status:
            # NOTE(review): logon status is not read from the simulation state yet; it is always reported as 0.
            obs["logon_status"] = 0
        return obs

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        space_shape = {
            "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
            "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
            "operating_status": spaces.Discrete(5),
            "NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
        }
        if self.logon_status:
            space_shape["logon_status"] = spaces.Discrete(3)
        return spaces.Dict(space_shape)

    @classmethod
    def from_config(
        cls,
        config: Dict,
        session: "PrimaiteGame",
        parent_where: Optional[List[str]] = None,
        num_services_per_node: int = 2,
        num_folders_per_node: int = 2,
        num_files_per_folder: int = 2,
        num_nics_per_node: int = 2,
    ) -> "NodeObservation":
        """
        Create node observation from a config. Also creates child service, folder and NIC observations.

        One NIC observation is created for every NIC on the node, in the order of the node's ``nics`` mapping.

        :param config: Dictionary containing the configuration for this node observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :param parent_where: Where in the simulation state dictionary to find the information about this node's parent
            network. A typical location for it would be: ['network',]. If None, ['network'] is assumed.
        :type parent_where: Optional[List[str]]
        :param num_services_per_node: How many spaces for services are in this node observation (to preserve static
            observation size), defaults to 2
        :type num_services_per_node: int, optional
        :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
            observation size), defaults to 2
        :type num_folders_per_node: int, optional
        :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
            observation size), defaults to 2
        :type num_files_per_folder: int, optional
        :param num_nics_per_node: How many spaces for NICs are in this node observation (to preserve static
            observation size), defaults to 2
        :type num_nics_per_node: int, optional
        :return: Constructed node observation
        :rtype: NodeObservation
        """
        node_uuid = session.ref_map_nodes[config["node_ref"]]
        if parent_where is None:
            where = ["network", "nodes", node_uuid]
        else:
            where = parent_where + ["nodes", node_uuid]

        svc_configs = config.get("services", [])
        services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]

        folder_configs = config.get("folders", [])
        folders = [
            FolderObservation.from_config(
                config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder
            )
            for c in folder_configs
        ]

        # One config per NIC. (A dict comprehension here would collapse all NICs into a single config.)
        nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
        nic_configs = [{"nic_uuid": nic_uuid} for nic_uuid in nic_uuids]
        nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]

        logon_status = config.get("logon_status", False)
        return cls(
            where=where,
            services=services,
            folders=folders,
            nics=nics,
            logon_status=logon_status,
            num_services_per_node=num_services_per_node,
            num_folders_per_node=num_folders_per_node,
            num_files_per_folder=num_files_per_folder,
            num_nics_per_node=num_nics_per_node,
        )
class AclObservation(AbstractObservation):
    """Observation of an Access Control List (ACL) in the network."""

    # TODO: should where be optional, and we can use where=None to pad the observation space?
    # definitely the current approach does not support tracking files that aren't specified by name, for example
    # if a file is created at runtime, we have currently got no way of telling the observation space to track it.
    # this needs adding, but not for the MVP.

    def __init__(
        self,
        node_ip_to_id: Dict[str, int],
        ports: List[int],
        protocols: List[str],
        where: Optional[Tuple[str]] = None,
        num_rules: int = 10,
    ) -> None:
        """
        Initialise ACL observation.

        :param node_ip_to_id: Mapping between IP address and node ID. IDs 0 (unused) and 1 (any) are reserved, so
            real nodes should use IDs from 2 upwards (``from_config`` assigns them that way).
        :type node_ip_to_id: Dict[str, int]
        :param ports: Ports which are part of the game; their order defines their IDs, starting at 2.
        :type ports: List[int]
        :param protocols: Protocols which are part of the game; their order defines their IDs, starting at 2.
        :type protocols: list[str]
        :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
            example may look like this:
            ['network','nodes',<router_uuid>,'acl','acl']
            If None, the observation is always the zeroed default.
        :type where: Optional[Tuple[str]], optional
        :param num_rules: Number of ACL rule slots in the observation space, defaults to 10
        :type num_rules: int, optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where
        self.num_rules: int = num_rules
        # IP address -> node ID.
        self.node_to_id: Dict[str, int] = node_ip_to_id
        # Port -> ID; offset by 2 to skip the reserved values 0 (unused) and 1 (any).
        self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
        # Protocol -> ID; offset by 2 to skip the reserved values 0 (unused) and 1 (any).
        self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
        self.default_observation: Dict = {i + 1: self._empty_rule(i) for i in range(self.num_rules)}

    @staticmethod
    def _empty_rule(position: int) -> Dict:
        """Return the all-zero encoding used for an empty or absent rule at ``position``."""
        return {
            "position": position,
            "permission": 0,
            "source_node_id": 0,
            "source_port": 0,
            "dest_node_id": 0,
            "dest_port": 0,
            "protocol": 0,
        }

    def _encode_rule(self, position: int, rule_state: Dict) -> Dict:
        """
        Encode a populated ACL rule's state into observation IDs.

        A field whose value is None is encoded as 1, the reserved "any" value.
        NOTE(review): this assumes the ACL state reports wildcard fields as None; confirm against the ACL rule state.
        Any other value must be present in the corresponding mapping, otherwise a KeyError is raised.
        """

        def to_id(value: Any, mapping: Dict) -> int:
            return 1 if value is None else mapping[value]

        return {
            "position": position,
            "permission": rule_state["action"],
            "source_node_id": to_id(rule_state["src_ip_address"], self.node_to_id),
            "source_port": to_id(rule_state["src_port"], self.port_to_id),
            "dest_node_id": to_id(rule_state["dst_ip_address"], self.node_to_id),
            "dest_port": to_id(rule_state["dst_port"], self.port_to_id),
            "protocol": to_id(rule_state["protocol"], self.protocol_to_id),
        }

    def observe(self, state: Dict) -> Dict:
        """
        Generate observation based on the current state of the simulation.

        Exactly ``num_rules`` entries are always produced so the observation matches ``space``. Positions missing
        from the ACL state (or holding None) are reported as empty rules. Rules at positions ``>= num_rules`` are
        not observed.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        acl_state: Dict = access_from_nested_dict(state, self.where)
        if acl_state is NOT_PRESENT_IN_STATE:
            return self.default_observation
        obs = {}
        for i in range(self.num_rules):
            # ACL state is keyed by integer rule position, starting at 0.
            rule_state = acl_state.get(i)
            obs[i + 1] = self._empty_rule(i) if rule_state is None else self._encode_rule(i, rule_state)
        return obs

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        :return: Gymnasium space
        :rtype: spaces.Space
        """
        # Adding two to each size accounts for the reserved values 0 (unused) and 1 (any).
        num_node_ids = len(set(self.node_to_id.values())) + 2
        num_port_ids = len(self.port_to_id) + 2
        num_protocol_ids = len(self.protocol_to_id) + 2
        return spaces.Dict(
            {
                i
                + 1: spaces.Dict(
                    {
                        "position": spaces.Discrete(self.num_rules),
                        "permission": spaces.Discrete(3),
                        "source_node_id": spaces.Discrete(num_node_ids),
                        "source_port": spaces.Discrete(num_port_ids),
                        "dest_node_id": spaces.Discrete(num_node_ids),
                        "dest_port": spaces.Discrete(num_port_ids),
                        "protocol": spaces.Discrete(num_protocol_ids),
                    }
                )
                for i in range(self.num_rules)
            }
        )

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteGame") -> "AclObservation":
        """
        Generate ACL observation from a config.

        Node IDs are assigned in the order of ``config["ip_address_order"]``, starting at 2.

        :param config: Dictionary containing the configuration for this ACL observation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Observation object
        :rtype: AclObservation
        """
        max_acl_rules = config["options"]["max_acl_rules"]
        node_ip_to_idx = {}
        for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
            node_ref = ip_map_config["node_ref"]
            nic_num = ip_map_config["nic_num"]
            node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
            nic_obj = node_obj.ethernet_port[nic_num]
            node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
        router_uuid = session.ref_map_nodes[config["router_node_ref"]]
        return cls(
            node_ip_to_id=node_ip_to_idx,
            ports=session.options.ports,
            protocols=session.options.protocols,
            where=["network", "nodes", router_uuid, "acl", "acl"],
            num_rules=max_acl_rules,
        )
class NullObservation(AbstractObservation):
    """Null observation, returns a single 0 value for the observation space."""

    def __init__(self, where: Optional[List[str]] = None):
        """
        Initialise null observation.

        :param where: Ignored; accepted only so the signature matches the other observation classes.
        :type where: Optional[List[str]], optional
        """
        super().__init__()
        # Must be a member of ``space`` (Discrete(1)), so it is the same single value 0 that ``observe`` returns.
        # An empty dict is not a valid Discrete(1) sample, and containers such as UC2BlueObservation embed this value.
        self.default_observation: int = 0

    def observe(self, state: Dict) -> int:
        """Generate observation based on the current state of the simulation. Always 0."""
        return 0

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return spaces.Discrete(1)

    @classmethod
    def from_config(cls, config: Dict, session: Optional["PrimaiteGame"] = None) -> "NullObservation":
        """
        Create null observation from a config.

        The parameters are ignored, they are here to match the signature of the other observation classes.
        """
        return cls()
class ICSObservation(NullObservation):
    """
    ICS observation placeholder.

    Not implemented yet: everything is inherited from NullObservation, so it always observes a single 0.
    """
class UC2BlueObservation(AbstractObservation):
    """
    Container for all observations used by the blue agent in UC2.

    TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
    for the purpose of compiling several observation components.
    """

    def __init__(
        self,
        nodes: List[NodeObservation],
        links: List[LinkObservation],
        acl: AclObservation,
        ics: ICSObservation,
        where: Optional[List[str]] = None,
    ) -> None:
        """
        Bundle the node, link, ACL and ICS observations seen by the blue agent.

        :param nodes: Node observations; numbered from 1 in the output, in list order.
        :type nodes: List[NodeObservation]
        :param links: Link observations; numbered from 1 in the output, in list order.
        :type links: List[LinkObservation]
        :param acl: The Access Control List observation
        :type acl: AclObservation
        :param ics: The ICS observation
        :type ics: ICSObservation
        :param where: Only compared against None: when None, ``observe`` always returns the default observation.
            The child observations locate their own data, so this container reads nothing from it.
            Defaults to None.
        :type where: Optional[List[str]], optional
        """
        super().__init__()
        self.where: Optional[Tuple[str]] = where
        self.nodes: List[NodeObservation] = nodes
        self.links: List[LinkObservation] = links
        self.acl: AclObservation = acl
        self.ics: ICSObservation = ics
        self.default_observation: Dict = {
            "NODES": {idx: node.default_observation for idx, node in enumerate(self.nodes, start=1)},
            "LINKS": {idx: link.default_observation for idx, link in enumerate(self.links, start=1)},
            "ACL": self.acl.default_observation,
            "ICS": self.ics.default_observation,
        }

    def observe(self, state: Dict) -> Dict:
        """
        Collect every child observation into one nested dictionary.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        return {
            "NODES": {idx: node.observe(state) for idx, node in enumerate(self.nodes, start=1)},
            "LINKS": {idx: link.observe(state) for idx, link in enumerate(self.links, start=1)},
            "ACL": self.acl.observe(state),
            "ICS": self.ics.observe(state),
        }

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space combining the spaces of every child observation.

        :return: Space
        :rtype: spaces.Space
        """
        node_spaces = spaces.Dict({idx: node.space for idx, node in enumerate(self.nodes, start=1)})
        link_spaces = spaces.Dict({idx: link.space for idx, link in enumerate(self.links, start=1)})
        return spaces.Dict({"NODES": node_spaces, "LINKS": link_spaces, "ACL": self.acl.space, "ICS": self.ics.space})

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2BlueObservation":
        """
        Create a UC2 blue observation, and all of its child observations, from a config.

        :param config: Configuration with the keys ``nodes``, ``links``, ``acl``, ``ics`` and the four sizing keys
            ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node``.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Constructed UC2 blue observation
        :rtype: UC2BlueObservation
        """
        node_configs = config["nodes"]
        sizing = {
            key: config[key]
            for key in ("num_services_per_node", "num_folders_per_node", "num_files_per_folder", "num_nics_per_node")
        }
        nodes = [NodeObservation.from_config(config=node_cfg, session=session, **sizing) for node_cfg in node_configs]
        links = [LinkObservation.from_config(config=link_cfg, session=session) for link_cfg in config["links"]]
        acl = AclObservation.from_config(config=config["acl"], session=session)
        ics = ICSObservation.from_config(config=config["ics"], session=session)
        return cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
class UC2RedObservation(AbstractObservation):
    """Container for all observations used by the red agent in UC2."""

    def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
        """
        Initialise UC2 red observation.

        :param nodes: Node observations; numbered from 1 in the output, in list order.
        :type nodes: List[NodeObservation]
        :param where: Only compared against None: when None, ``observe`` always returns the default observation.
            Defaults to None.
        :type where: Optional[List[str]], optional
        """
        super().__init__()
        self.where: Optional[List[str]] = where
        self.nodes: List[NodeObservation] = nodes
        self.default_observation: Dict = {
            "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
        }

    def observe(self, state: Dict) -> Dict:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: Observation
        :rtype: Dict
        """
        if self.where is None:
            return self.default_observation
        obs = {}
        obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
        return obs

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return spaces.Dict(
            {
                "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
            }
        )

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteGame") -> "UC2RedObservation":
        """
        Create UC2 red observation from a config.

        :param config: Configuration for this UC2 red observation. ``nodes`` is required. The optional keys
            ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder`` and ``num_nics_per_node``
            size the node observations (each defaults to 2, the previous fixed value), mirroring UC2BlueObservation.
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Constructed UC2 red observation
        :rtype: UC2RedObservation
        """
        node_configs = config["nodes"]
        sizing = {
            key: config.get(key, 2)
            for key in ("num_services_per_node", "num_folders_per_node", "num_files_per_folder", "num_nics_per_node")
        }
        nodes = [NodeObservation.from_config(config=cfg, session=session, **sizing) for cfg in node_configs]
        return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation):
    """
    Green agent observation.

    The green agent's actions don't depend on what it observes, so this is an empty (null) observation that
    inherits all of its behaviour from NullObservation.
    """
class ObservationManager:
    """
    Manage the observations of an Agent.

    The observation space has the purpose of:
        1. Reading the outputted state from the PrimAITE Simulation.
        2. Selecting parts of the simulation state that are requested by the simulation config
        3. Formatting this information so an agent can use it to make decisions.
    """

    # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
    # to have more of its own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
    # refactor.

    def __init__(self, observation: AbstractObservation) -> None:
        """
        Initialise observation space.

        :param observation: Observation object
        :type observation: AbstractObservation
        """
        self.obs: AbstractObservation = observation
        # Most recent observation produced by ``update``; None until ``update`` has been called at least once.
        self.current_observation: Optional[ObsType] = None

    def update(self, state: Dict) -> Dict:
        """
        Generate observation based on the current state of the simulation.

        The result is also stored in ``current_observation``.

        :param state: Simulation state dictionary
        :type state: Dict
        :return: The new observation
        """
        self.current_observation = self.obs.observe(state)
        return self.current_observation

    @property
    def space(self) -> spaces.Space:
        """Gymnasium space object describing the observation space shape."""
        return self.obs.space

    @classmethod
    def from_config(cls, config: Dict, session: "PrimaiteGame") -> "ObservationManager":
        """
        Create observation space from a config.

        :param config: Dictionary containing the configuration for this observation space.
            It should contain the key 'type' which selects which observation class to use (from a choice of:
            UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
            The optional key 'options' is passed to the selected class's ``from_config`` (defaults to {}).
        :type config: Dict
        :param session: Reference to the PrimaiteSession object that spawned this observation.
        :type session: PrimaiteSession
        :return: Observation manager wrapping the configured observation.
        :raises ValueError: If 'type' does not name a supported observation class.
        """
        observation_classes = {
            "UC2BlueObservation": UC2BlueObservation,
            "UC2RedObservation": UC2RedObservation,
            "UC2GreenObservation": UC2GreenObservation,
        }
        obs_type = config["type"]
        if obs_type not in observation_classes:
            raise ValueError(
                f"Observation space type invalid: {obs_type!r}. Expected one of {sorted(observation_classes)}."
            )
        return cls(observation_classes[obs_type].from_config(config.get("options", {}), session=session))