Files
PrimAITE/src/primaite/game/agent/observations.py
2023-10-11 15:49:41 +01:00

667 lines
27 KiB
Python

from abc import ABC, abstractmethod
from typing import Any, Dict, Hashable, List, Optional, Sequence, Tuple, TYPE_CHECKING
from gymnasium import spaces
from pydantic import BaseModel
from primaite.simulator.sim_container import Simulation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite import getLogger
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession
class AbstractObservation(ABC):
@abstractmethod
def observe(self, state: Dict) -> Any:
"""_summary_
:param state: _description_
:type state: Dict
:return: _description_
:rtype: Any
"""
...
@property
@abstractmethod
def space(self) -> spaces.Space:
"""Subclasses must define the shape that they expect"""
...
@classmethod
@abstractmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
"""Create this observation space component form a serialised format.
The `session` parameter is for a the PrimaiteSession object that spawns this component. During deserialisation,
a subclass of this class may need to translate from a 'reference' to a UUID.
"""
class FileObservation(AbstractObservation):
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
_summary_
:param where: Store information about where in the simulation state dictionary to find the relevatn information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"health_status": file_state["health_status"]}
@property
def space(self) -> spaces.Space:
return spaces.Dict({"health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where=None):
return cls(where=parent_where + ["files", config["file_name"]])
class ServiceObservation(AbstractObservation):
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'services', <service_uuid>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"operating_status": service_state["operating_state"], "health_status": service_state["health_status"]}
@property
def space(self) -> spaces.Space:
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]] = None):
return cls(where=parent_where + ["services", session.ref_map_services[config["service_ref"]].uuid])
class LinkObservation(AbstractObservation):
default_observation: spaces.Space = {"protocols": {"all": {"load": 0}}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_uuid>,'servics', <service_uuid>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 10) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"protocols": {"all": {"load": utilisation_category}}}
@property
def space(self) -> spaces.Space:
return spaces.Dict({"protocols": spaces.Dict({"all": spaces.Dict({"load": spaces.Discrete(11)})})})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
return cls(where=["network", "links", session.ref_map_links[config["link_ref"]]])
class FolderObservation(AbstractObservation):
def __init__(self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder:int=2) -> None:
"""Initialise folder Observation, including files inside of the folder.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_uuid>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
:type max_files: int, optional
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
that even if new files are created, the existing files will always occupy the same space in the observation
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
name, it will take the position defined in this dict. Defaults to {}
:type file_positions: Dict[int, str], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.files: List[FileObservation] = files
while len(self.files) < num_files_per_folder:
self.files.append(FileObservation())
while len(self.files)> num_files_per_folder:
truncated_file = self.files.pop()
msg = f"Too many files in folde observation. Truncating file {truncated_file}"
_LOGGER.warn(msg)
raise UserWarning(msg)
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
health_status = folder_state["health_status"]
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]], num_files_per_folder:int=2):
where = parent_where + ["folders", config["folder_name"]]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, session=session, parent_where=where) for f in file_configs]
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
class NicObservation(AbstractObservation):
default_observation: spaces.Space = {"nic_status": 0}
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
return {"nic_status": 1 if nic_state["enabled"] else 2}
@property
def space(self) -> spaces.Space:
return spaces.Dict({"nic_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession", parent_where: Optional[List[str]]):
return cls(where=parent_where + ["NICs", config["nic_uuid"]])
class NodeObservation(AbstractObservation):
def __init__(
self,
where: Optional[Tuple[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
nics: List[NicObservation] = [],
logon_status: bool = False,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2
) -> None:
"""
Configurable observation for a node in the simulation.
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<node_uuid>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service UUID, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node, defaults to 2
:type max_services: int, optional
:param folders: Mapping between position in observation space and folder name, defaults to {}
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param nics: Mapping between position in observation space and NIC UUID, defaults to {}
:type nics: Dict[int,str], optional
:param max_nics: Max number of NICS in this node's obs space, defaults to 5
:type max_nics: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.services: List[ServiceObservation] = services
while len(self.services)<num_services_per_node:
# add an empty service observation without `where` parameter that will always return default (blank) observations
self.services.append(ServiceObservation())
while len(self.services)>num_services_per_node:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warn(msg)
raise UserWarning(msg)
# truncate service list
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders_per_node:
# add an empty folder observation without `where` parameter that will always return default (blank) observations
self.folders.append(FolderObservation())
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
self.nics: List[NicObservation] = nics
self.logon_status: bool = logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.nics)},
"operating_status": 0,
}
if self.logon_status:
self.default_observation["logon_status"] = 0
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.logon_status:
obs["logon_status"] = 0
return obs
@property
def space(self) -> spaces.Space:
space_shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)}),
}
if self.logon_status:
space_shape["logon_status"] = spaces.Discrete(3)
return spaces.Dict(space_shape)
@classmethod
def from_config(
cls,
config: Dict,
session: "PrimaiteSession",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
) -> "NodeObservation":
node_uuid = session.ref_map_nodes[config["node_ref"]]
if parent_where is None:
where = ["network", "nodes", node_uuid]
else:
where = parent_where + ["nodes", node_uuid]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, session=session, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [FolderObservation.from_config(config=c, session=session, parent_where=where, num_files_per_folder=num_files_per_folder) for c in folder_configs]
nic_uuids = session.simulation.network.nodes[node_uuid].nics.keys()
nic_configs = [{"nic_uuid": n for n in nic_uuids}] if nic_uuids else []
nics = [NicObservation.from_config(config=c, session=session, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
services=services,
folders=folders,
nics=nics,
logon_status=logon_status,
num_services_per_node = num_services_per_node,
num_folders_per_node = num_folders_per_node,
num_files_per_folder = num_files_per_folder,
)
class AclObservation(AbstractObservation):
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: list[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
"RULES": {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
}
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
obs["RULES"] = {}
for i, rule_state in acl_state.items():
if rule_state is None:
obs["RULES"][i + 1] = {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
obs["RULES"][i + 1] = {
"position": i,
"permission": rule_state["action"],
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
"source_port": self.port_to_id[rule_state["src_port"]],
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
"dest_port": self.port_to_id[rule_state["dst_port"]],
"protocol": self.protocol_to_id[rule_state["protocol"]],
}
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict(
{
"RULES": spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "AclObservation":
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_ref"]
nic_num = ip_map_config["nic_num"]
node_obj = session.simulation.network.nodes[session.ref_map_nodes[node_ref]]
nic_obj = node_obj.ethernet_port[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_uuid = session.ref_map_nodes[config["router_node_ref"]]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=session.options.ports,
protocols=session.options.protocols,
where=["network", "nodes", router_uuid, "acl", "acl"],
)
class NullObservation(AbstractObservation):
def __init__(self, where: Optional[List[str]] = None):
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
return 0
@property
def space(self) -> spaces.Space:
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, session: Optional["PrimaiteSession"] = None) -> "NullObservation":
return cls()
class ICSObservation(NullObservation):
pass
class UC2BlueObservation(AbstractObservation):
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where: Optional[List[str]] = None,
) -> None:
super().__init__()
self.where: Optional[Tuple[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs["ACL"] = self.acl.observe(state)
obs["ICS"] = self.ics.observe(state)
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]
nodes = [NodeObservation.from_config(
config=n,
session=session,
num_services_per_node= num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
) for n in node_configs]
link_configs = config["links"]
links = [LinkObservation.from_config(config=l, session=session) for l in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, session=session)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, session=session)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
class UC2RedObservation(AbstractObservation):
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
}
def observe(self, state: Dict) -> Dict:
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
return obs
@property
def space(self) -> spaces.Space:
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
}
)
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession"):
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, session=session) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation):
pass
class ObservationSpace:
"""
Manage the observations of an Actor.
The observation space has the purpose of:
1. Reading the outputted state from the PrimAITE Simulation.
2. Selecting parts of the simulation state that are requested by the simulation config
3. Formatting this information so an actor can use it to make decisions.
"""
...
# what this class does:
# keep a list of observations
# create observations for an actor from the config
def __init__(self, observation: AbstractObservation) -> None:
self.obs: AbstractObservation = observation
def observe(self, state) -> Dict:
return self.obs.observe(state)
@property
def space(self) -> None:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), session=session))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), session=session))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), session=session))
else:
raise ValueError("Observation space type invalid")