#2417 Move classes to correct files

This commit is contained in:
Marek Wolan
2024-03-29 14:14:03 +00:00
parent 9123aff592
commit 22e1dfea2f
10 changed files with 1332 additions and 1913 deletions

View File

@@ -0,0 +1,187 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
class ACLObservation(AbstractObservation, identifier="ACL"):
"""ACL observation, provides information about access control lists within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ACLObservation."""
ip_list: Optional[List[IPv4Address]] = None
"""List of IP addresses."""
wildcard_list: Optional[List[str]] = None
"""List of wildcard strings."""
port_list: Optional[List[int]] = None
"""List of port numbers."""
protocol_list: Optional[List[str]] = None
"""List of protocol names."""
num_rules: Optional[int] = None
"""Number of ACL rules."""
def __init__(
self,
where: WhereType,
num_rules: int,
ip_list: List[IPv4Address],
wildcard_list: List[str],
port_list: List[int],
protocol_list: List[str],
) -> None:
"""
Initialize an ACL observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
:type where: WhereType
:param num_rules: Number of ACL rules.
:type num_rules: int
:param ip_list: List of IP addresses.
:type ip_list: List[IPv4Address]
:param wildcard_list: List of wildcard strings.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param protocol_list: List of protocol names.
:type protocol_list: List[str]
"""
self.where = where
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol_id": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing ACL rules.
:rtype: ObsType
"""
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol_id": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = self.ip_to_id.get(src_ip, 1)
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = self.ip_to_id.get(dst_ip, 1)
src_wildcard = rule_state["source_wildcard_id"]
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
dst_wildcard = rule_state["dest_wildcard_id"]
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
src_port = rule_state["source_port_id"]
src_port_id = self.port_to_id.get(src_port, 1)
dst_port = rule_state["dest_port_id"]
dst_port_id = self.port_to_id.get(dst_port, 1)
protocol = rule_state["protocol"]
protocol_id = self.protocol_to_id.get(protocol, 1)
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_ip_id": src_node_id,
"source_wildcard_id": src_wildcard_id,
"source_port_id": src_port_id,
"dest_ip_id": dst_node_ip,
"dest_wildcard_id": dst_wildcard_id,
"dest_port_id": dst_port_id,
"protocol_id": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for ACL rules.
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
"""
Create an ACL observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the ACL observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this ACL's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed ACL observation instance.
:rtype: ACLObservation
"""
return cls(
where=parent_where + ["acl", "acl"],
num_rules=config.num_rules,
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,
protocol_list=config.protocol_list,
)

View File

@@ -1,138 +0,0 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.observations import (
AbstractObservation,
AclObservation,
ICSObservation,
LinkObservation,
NullObservation,
)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class UC2BlueObservation(AbstractObservation):
"""Container for all observations used by the blue agent in UC2.
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
for the purpose of compiling several observation components.
"""
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where: Optional[List[str]] = None,
) -> None:
"""Initialise UC2 blue observation.
:param nodes: List of node observations
:type nodes: List[NodeObservation]
:param links: List of link observations
:type links: List[LinkObservation]
:param acl: The Access Control List observation
:type acl: AclObservation
:param ics: The ICS observation
:type ics: ICSObservation
:param where: Where in the simulation state dict to find information. Not used in this particular observation
because it only compiles other observations and doesn't contribute any new information, defaults to None
:type where: Optional[List[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs["ACL"] = self.acl.observe(state)
obs["ICS"] = self.ics.observe(state)
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]
num_nics_per_node = config["num_nics_per_node"]
nodes = [
NodeObservation.from_config(
config=n,
game=game,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
for n in node_configs
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, game=game)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, game=game)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new

View File

@@ -1,107 +1,130 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, Iterable, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class FileObservation(AbstractObservation, identifier="FILE"):
"""File observation, provides status information about a file within the simulation environment."""
class FileObservation(AbstractObservation):
"""Observation of a file on a node in the network."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FileObservation."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
file_name: str
"""Name of the file, used for querying simulation state dictionary."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to the file in the observation."""
def __init__(self, where: WhereType, include_num_access: bool) -> None:
"""
Initialise file observation.
Initialize a file observation instance.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
:param where: Where in the simulation state dictionary to find the relevant information for this file.
A typical location for a file might be
['network', 'nodes', <node_hostname>, 'file_system', 'folder', <folder_name>, 'files', <file_name>].
:type where: WhereType
:param include_num_access: Whether to include the number of accesses to the file in the observation.
:type include_num_access: bool
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
self.where: WhereType = where
self.include_num_access: bool = include_num_access
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
self.default_observation: ObsType = {"health_status": 0}
if self.include_num_access:
self.default_observation["num_access"] = 0
:param state: Simulation state dictionary
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing the health status of the file and optionally the number of accesses.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"health_status": file_state["visible_status"]}
obs = {"health_status": file_state["visible_status"]}
if self.include_num_access:
obs["num_access"] = file_state["num_access"]
# raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.")
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:return: Gymnasium space representing the observation space for file status.
:rtype: spaces.Space
"""
return spaces.Dict({"health_status": spaces.Discrete(6)})
space = {"health_status": spaces.Discrete(6)}
if self.include_num_access:
space["num_access"] = spaces.Discrete(4)
return spaces.Dict(space)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param game: _description_
:type game: PrimaiteGame
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
:rtype: _type_
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
"""
return cls(where=parent_where + ["files", config["file_name"]])
Create a file observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the file observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this file's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed file observation instance.
:rtype: FileObservation
"""
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
class FolderObservation(AbstractObservation):
"""Folder observation, including files inside of the folder."""
class FolderObservation(AbstractObservation, identifier="FOLDER"):
"""Folder observation, provides status information about a folder within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FolderObservation."""
folder_name: str
"""Name of the folder, used for querying simulation state dictionary."""
files: List[FileObservation.ConfigSchema] = []
"""List of file configurations within the folder."""
num_files: Optional[int] = None
"""Number of spaces for file observations in this folder."""
include_num_access: Optional[bool] = None
"""Whether files in this folder should include the number of accesses in their observation."""
def __init__(
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
) -> None:
"""Initialise folder Observation, including files inside the folder.
"""
Initialize a folder observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
:type max_files: int, optional
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
that even if new files are created, the existing files will always occupy the same space in the observation
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
name, it will take the position defined in this dict. Defaults to {}
:type file_positions: Dict[int, str], optional
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
:type where: WhereType
:param files: List of file observation instances within the folder.
:type files: Iterable[FileObservation]
:param num_files: Number of files expected in the folder.
:type num_files: int
:param include_num_access: Whether to include the number of accesses to files in the observation.
:type include_num_access: bool
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.where: WhereType = where
self.files: List[FileObservation] = files
while len(self.files) < num_files_per_folder:
self.files.append(FileObservation())
while len(self.files) > num_files_per_folder:
while len(self.files) < num_files:
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
while len(self.files) > num_files:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
_LOGGER.warning(msg)
@@ -111,16 +134,15 @@ class FolderObservation(AbstractObservation):
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the health status of the folder and status of files within the folder.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
@@ -136,9 +158,10 @@ class FolderObservation(AbstractObservation):
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:return: Gymnasium space representing the observation space for folder status.
:rtype: spaces.Space
"""
return spaces.Dict(
@@ -149,29 +172,23 @@ class FolderObservation(AbstractObservation):
)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
"""
Create a folder observation from a configuration schema.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param config: Configuration schema containing the necessary information for the folder observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_hostname>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed folder observation
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed folder observation instance.
:rtype: FolderObservation
"""
where = parent_where + ["folders", config["folder_name"]]
where = parent_where + ["folders", config.folder_name]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
# pass down shared/common config items
for file_config in config.files:
file_config.include_num_access = config.include_num_access
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)

View File

@@ -0,0 +1,213 @@
from __future__ import annotations
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
_LOGGER = getLogger(__name__)
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""Firewall observation, provides status information about a firewall within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FirewallObservation."""
hostname: str
"""Hostname of the firewall node, used for querying simulation state dictionary."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
def __init__(
self,
where: WhereType,
ip_list: List[str],
wildcard_list: List[str],
port_list: List[int],
protocol_list: List[str],
num_rules: int,
) -> None:
"""
Initialize a firewall observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
:type where: WhereType
:param ip_list: List of IP addresses.
:type ip_list: List[str]
:param wildcard_list: List of wildcard rules.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param protocol_list: List of protocol types.
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
"""
self.where: WhereType = where
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3)
]
# TODO: check what the port nums are for firewall.
self.internal_inbound_acl = ACLObservation(
where=self.where + ["acl", "internal", "inbound"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.internal_outbound_acl = ACLObservation(
where=self.where + ["acl", "internal", "outbound"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.dmz_inbound_acl = ACLObservation(
where=self.where + ["acl", "dmz", "inbound"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.dmz_outbound_acl = ACLObservation(
where=self.where + ["acl", "dmz", "outbound"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.external_inbound_acl = ACLObservation(
where=self.where + ["acl", "external", "inbound"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.external_outbound_acl = ACLObservation(
where=self.where + ["acl", "external", "outbound"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.default_observation = {
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.default_observation,
"OUTBOUND": self.internal_outbound_acl.default_observation,
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.default_observation,
"OUTBOUND": self.dmz_outbound_acl.default_observation,
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.default_observation,
"OUTBOUND": self.external_outbound_acl.default_observation,
},
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
:rtype: ObsType
"""
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
}
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for firewall status.
:rtype: spaces.Space
"""
space = spaces.Dict(
{
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
"INTERNAL": spaces.Dict(
{
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}
),
"DMZ": spaces.Dict(
{
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}
),
"EXTERNAL": spaces.Dict(
{
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}
),
}
)
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
"""
Create a firewall observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the firewall observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this firewall's
parent node. A typical location for a node might be ['network', 'nodes', <firewall_hostname>].
:type parent_where: WhereType, optional
:return: Constructed firewall observation instance.
:rtype: FirewallObservation
"""
where = parent_where + ["nodes", config.hostname]
return cls(
where=where,
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,
protocol_list=config.protocol_list,
num_rules=config.num_rules,
)

View File

@@ -0,0 +1,229 @@
from __future__ import annotations
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
class HostObservation(AbstractObservation, identifier="HOST"):
"""Host observation, provides status information about a host within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for HostObservation."""
hostname: str
"""Hostname of the host, used for querying simulation state dictionary."""
services: List[ServiceObservation.ConfigSchema] = []
"""List of services to observe on the host."""
applications: List[ApplicationObservation.ConfigSchema] = []
"""List of applications to observe on the host."""
folders: List[FolderObservation.ConfigSchema] = []
"""List of folders to observe on the host."""
network_interfaces: List[NICObservation.ConfigSchema] = []
"""List of network interfaces to observe on the host."""
num_services: Optional[int] = None
"""Number of spaces for service observations on this host."""
num_applications: Optional[int] = None
"""Number of spaces for application observations on this host."""
num_folders: Optional[int] = None
"""Number of spaces for folder observations on this host."""
num_files: Optional[int] = None
"""Number of spaces for file observations on this host."""
num_nics: Optional[int] = None
"""Number of spaces for network interface observations on this host."""
include_nmne: Optional[bool] = None
"""Whether network interface observations should include number of malicious network events."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to files observations on this host."""
def __init__(
self,
where: WhereType,
services: List[ServiceObservation],
applications: List[ApplicationObservation],
folders: List[FolderObservation],
network_interfaces: List[NICObservation],
num_services: int,
num_applications: int,
num_folders: int,
num_files: int,
num_nics: int,
include_nmne: bool,
include_num_access: bool,
) -> None:
"""
Initialize a host observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this host.
A typical location for a host might be ['network', 'nodes', <hostname>].
:type where: WhereType
:param services: List of service observations on the host.
:type services: List[ServiceObservation]
:param applications: List of application observations on the host.
:type applications: List[ApplicationObservation]
:param folders: List of folder observations on the host.
:type folders: List[FolderObservation]
:param network_interfaces: List of network interface observations on the host.
:type network_interfaces: List[NICObservation]
:param num_services: Number of services to observe.
:type num_services: int
:param num_applications: Number of applications to observe.
:type num_applications: int
:param num_folders: Number of folders to observe.
:type num_folders: int
:param num_files: Number of files.
:type num_files: int
:param num_nics: Number of network interfaces.
:type num_nics: int
:param include_nmne: Flag to include network metrics and errors.
:type include_nmne: bool
:param include_num_access: Flag to include the number of accesses to files.
:type include_num_access: bool
"""
self.where: WhereType = where
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
while len(self.services) < num_services:
self.services.append(ServiceObservation(where=None))
while len(self.services) > num_services:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
self.applications: List[ApplicationObservation] = applications
while len(self.applications) < num_applications:
self.applications.append(ApplicationObservation(where=None))
while len(self.applications) > num_applications:
truncated_application = self.applications.pop()
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
_LOGGER.warning(msg)
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders:
self.folders.append(
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
)
while len(self.folders) > num_folders:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
_LOGGER.warning(msg)
self.network_interfaces: List[NICObservation] = network_interfaces
while len(self.network_interfaces) < num_nics:
self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne))
while len(self.network_interfaces) > num_nics:
truncated_nic = self.network_interfaces.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
_LOGGER.warning(msg)
self.default_observation: ObsType = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
"num_file_creations": 0,
"num_file_deletions": 0,
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status information about the host.
:rtype: ObsType
"""
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for host status.
:rtype: spaces.Space
"""
shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
"num_file_creations": spaces.Discrete(4),
"num_file_deletions": spaces.Discrete(4),
}
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation:
"""
Create a host observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the host observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this host.
A typical location might be ['network', 'nodes', <hostname>].
:type parent_where: WhereType, optional
:return: Constructed host observation instance.
:rtype: HostObservation
"""
if parent_where is None:
where = ["network", "nodes", config.hostname]
else:
where = parent_where + ["nodes", config.hostname]
# Pass down shared/common config items
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
return cls(
where=where,
services=services,
applications=applications,
folders=folders,
network_interfaces=nics,
num_services=config.num_services,
num_applications=config.num_applications,
num_folders=config.num_folders,
num_files=config.num_files,
num_nics=config.num_nics,
include_nmne=config.include_nmne,
include_num_access=config.include_num_access,
)

View File

@@ -1,188 +1,157 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
"""Status information about a network interface within the simulation environment."""
low_nmne_threshold: int = 0
"""The minimum number of malicious network events to be considered low."""
med_nmne_threshold: int = 5
"""The minimum number of malicious network events to be considered medium."""
high_nmne_threshold: int = 10
"""The minimum number of malicious network events to be considered high."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NICObservation."""
global CAPTURE_NMNE
nic_num: int
"""Number of the network interface."""
include_nmne: Optional[bool] = None
"""Whether to include number of malicious network events (NMNE) in the observation."""
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
return data
def __init__(
self,
where: Optional[Tuple[str]] = None,
low_nmne_threshold: Optional[int] = 0,
med_nmne_threshold: Optional[int] = 5,
high_nmne_threshold: Optional[int] = 10,
) -> None:
"""Initialise NIC observation.
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
def __init__(self, where: WhereType, include_nmne: bool) -> None:
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialize a network interface observation instance.
global CAPTURE_NMNE
if CAPTURE_NMNE:
self.nmne_inbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
self.nmne_outbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
self._validate_nmne_categories(
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
def _validate_nmne_categories(
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
):
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
A typical location for a network interface might be
['network', 'nodes', <node_hostname>, 'NICs', <nic_num>].
:type where: WhereType
:param include_nmne: Flag to determine whether to include NMNE information in the observation.
:type include_nmne: bool
"""
Validates the nmne threshold config.
self.where = where
self.include_nmne: bool = include_nmne
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
self.default_observation: ObsType = {"nic_status": 0}
if self.include_nmne:
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
def observe(self, state: Dict) -> ObsType:
"""
if high_nmne_threshold <= med_nmne_threshold:
raise Exception(
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
f"than medium nmne count ({med_nmne_threshold})"
)
Generate observation based on the current state of the simulation.
if med_nmne_threshold <= low_nmne_threshold:
raise Exception(
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
f"than low nmne count ({low_nmne_threshold})"
)
self.high_nmne_threshold = high_nmne_threshold
self.med_nmne_threshold = med_nmne_threshold
self.low_nmne_threshold = low_nmne_threshold
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (default 1-5 events).
- 2: Moderate number of MNEs (default 6-10 events).
- 3: High number of MNEs (default more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > self.high_nmne_threshold:
return 3
elif nmne_count > self.med_nmne_threshold:
return 2
elif nmne_count > self.low_nmne_threshold:
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing the status of the network interface and optionally NMNE information.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs_dict
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
if self.include_nmne:
obs.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for network interface status and NMNE information.
:rtype: spaces.Space
"""
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if CAPTURE_NMNE:
if self.include_nmne:
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
"""
low_nmne_threshold = None
med_nmne_threshold = None
high_nmne_threshold = None
Create a network interface observation from a configuration schema.
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
threshold = game.options.thresholds["nmne"]
:param config: Configuration schema containing the necessary information for the network interface observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed network interface observation instance.
:rtype: NICObservation
"""
return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne)
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
return cls(
where=parent_where + ["NICs", config["nic_num"]],
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
class PortObservation(AbstractObservation, identifier="PORT"):
"""Port observation, provides status information about a network port within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for PortObservation."""
port_id: int
"""Identifier of the port, used for querying simulation state dictionary."""
def __init__(self, where: WhereType) -> None:
"""
Initialize a port observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this port.
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
:type where: WhereType
"""
self.where = where
self.default_observation: ObsType = {"operating_status": 0}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the operating status of the port.
:rtype: ObsType
"""
port_state = access_from_nested_dict(state, self.where)
if port_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"operating_status": 1 if port_state["enabled"] else 2}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for port status.
:rtype: spaces.Space
"""
return spaces.Dict({"operating_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
"""
Create a port observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the port observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this port's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed port observation instance.
:rtype: PortObservation
"""
return cls(where=parent_where + ["NICs", config.port_id])

File diff suppressed because it is too large Load Diff

View File

@@ -1,24 +1,23 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type
from typing import Any, Dict, Iterable, Type
from gymnasium import spaces
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
WhereType = Iterable[str | int] | None
class AbstractObservation(ABC):
"""Abstract class for an observation space component."""
class ConfigSchema(ABC, BaseModel):
"""Config schema for observations."""
model_config = ConfigDict(extra="forbid")
_registry: Dict[str, Type["AbstractObservation"]] = {}
@@ -61,269 +60,271 @@ class AbstractObservation(ABC):
@classmethod
def from_config(cls, cfg: Dict) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
ObservationType = cls._registry[cfg['type']]
ObservationType = cls._registry[cfg["type"]]
return ObservationType.from_config(cfg=cfg)
# class LinkObservation(AbstractObservation):
# """Observation of a link in the network."""
'''
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
# "Default observation is what should be returned when the link doesn't exist."
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
# def __init__(self, where: Optional[Tuple[str]] = None) -> None:
# """Initialise link observation.
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
# :param where: Store information about where in the simulation state dictionary to find the relevant information.
# Optional. If None, this corresponds that the file does not exist and the observation will be populated with
# zeroes.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
# A typical location for a service looks like this:
# `['network','nodes',<node_hostname>,'servics', <service_name>]`
# :type where: Optional[List[str]]
# """
# super().__init__()
# self.where: Optional[Tuple[str]] = where
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
# def observe(self, state: Dict) -> Dict:
# """Generate observation based on the current state of the simulation.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
# :param state: Simulation state dictionary
# :type state: Dict
# :return: Observation
# :rtype: Dict
# """
# if self.where is None:
# return self.default_observation
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
# link_state = access_from_nested_dict(state, self.where)
# if link_state is NOT_PRESENT_IN_STATE:
# return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# bandwidth = link_state["bandwidth"]
# load = link_state["current_load"]
# if load == 0:
# utilisation_category = 0
# else:
# utilisation_fraction = load / bandwidth
# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
# utilisation_category = int(utilisation_fraction * 9) + 1
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# # TODO: once the links support separte load per protocol, this needs amendment to reflect that.
# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
# @property
# def space(self) -> spaces.Space:
# """Gymnasium space object describing the observation space shape.
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
# :return: Gymnasium space
# :rtype: spaces.Space
# """
# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
# @classmethod
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
# """Create link observation from a config.
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
# :param config: Dictionary containing the configuration for this link observation.
# :type config: Dict
# :param game: Reference to the PrimaiteGame object that spawned this observation.
# :type game: PrimaiteGame
# :return: Constructed link observation
# :rtype: LinkObservation
# """
# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
# class AclObservation(AbstractObservation):
# """Observation of an Access Control List (ACL) in the network."""
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# # TODO: should where be optional, and we can use where=None to pad the observation space?
# # definitely the current approach does not support tracking files that aren't specified by name, for example
# # if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# # this needs adding, but not for the MVP.
# def __init__(
# self,
# node_ip_to_id: Dict[str, int],
# ports: List[int],
# protocols: List[str],
# where: Optional[Tuple[str]] = None,
# num_rules: int = 10,
# ) -> None:
# """Initialise ACL observation.
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
# :param node_ip_to_id: Mapping between IP address and ID.
# :type node_ip_to_id: Dict[str, int]
# :param ports: List of ports which are part of the game that define the ordering when converting to an ID
# :type ports: List[int]
# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
# :type protocols: list[str]
# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
# example may look like this:
# ['network','nodes',<router_hostname>,'acl','acl']
# :type where: Optional[Tuple[str]], optional
# :param num_rules: , defaults to 10
# :type num_rules: int, optional
# """
# super().__init__()
# self.where: Optional[Tuple[str]] = where
# self.num_rules: int = num_rules
# self.node_to_id: Dict[str, int] = node_ip_to_id
# "List of node IP addresses, order in this list determines how they are converted to an ID"
# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
# "List of ports which are part of the game that define the ordering when converting to an ID"
# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
# "List of protocols which are part of the game, defines ordering when converting to an ID"
# self.default_observation: Dict = {
# i
# + 1: {
# "position": i,
# "permission": 0,
# "source_node_id": 0,
# "source_port": 0,
# "dest_node_id": 0,
# "dest_port": 0,
# "protocol": 0,
# }
# for i in range(self.num_rules)
# }
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
# def observe(self, state: Dict) -> Dict:
# """Generate observation based on the current state of the simulation.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
# :param state: Simulation state dictionary
# :type state: Dict
# :return: Observation
# :rtype: Dict
# """
# if self.where is None:
# return self.default_observation
# acl_state: Dict = access_from_nested_dict(state, self.where)
# if acl_state is NOT_PRESENT_IN_STATE:
# return self.default_observation
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# # TODO: what if the ACL has more rules than num of max rules for obs space
# obs = {}
# acl_items = dict(acl_state.items())
# i = 1 # don't show rule 0 for compatibility reasons.
# while i < self.num_rules + 1:
# rule_state = acl_items[i]
# if rule_state is None:
# obs[i] = {
# "position": i - 1,
# "permission": 0,
# "source_node_id": 0,
# "source_port": 0,
# "dest_node_id": 0,
# "dest_port": 0,
# "protocol": 0,
# }
# else:
# src_ip = rule_state["src_ip_address"]
# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
# dst_ip = rule_state["dst_ip_address"]
# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
# src_port = rule_state["src_port"]
# src_port_id = 1 if src_port is None else self.port_to_id[src_port]
# dst_port = rule_state["dst_port"]
# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
# protocol = rule_state["protocol"]
# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
# obs[i] = {
# "position": i - 1,
# "permission": rule_state["action"],
# "source_node_id": src_node_id,
# "source_port": src_port_id,
# "dest_node_id": dst_node_ip,
# "dest_port": dst_port_id,
# "protocol": protocol_id,
# }
# i += 1
# return obs
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
# @property
# def space(self) -> spaces.Space:
# """Gymnasium space object describing the observation space shape.
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
# :return: Gymnasium space
# :rtype: spaces.Space
# """
# return spaces.Dict(
# {
# i
# + 1: spaces.Dict(
# {
# "position": spaces.Discrete(self.num_rules),
# "permission": spaces.Discrete(3),
# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
# "source_port": spaces.Discrete(len(self.port_to_id) + 2),
# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
# "dest_port": spaces.Discrete(len(self.port_to_id) + 2),
# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
# }
# )
# for i in range(self.num_rules)
# }
# )
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
# @classmethod
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
# """Generate ACL observation from a config.
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
# :param config: Dictionary containing the configuration for this ACL observation.
# :type config: Dict
# :param game: Reference to the PrimaiteGame object that spawned this observation.
# :type game: PrimaiteGame
# :return: Observation object
# :rtype: AclObservation
# """
# max_acl_rules = config["options"]["max_acl_rules"]
# node_ip_to_idx = {}
# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
# node_ref = ip_map_config["node_hostname"]
# nic_num = ip_map_config["nic_num"]
# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
# nic_obj = node_obj.network_interface[nic_num]
# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_hostname"]
nic_num = ip_map_config["nic_num"]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.network_interface[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
# router_hostname = config["router_hostname"]
# return cls(
# node_ip_to_id=node_ip_to_idx,
# ports=game.options.ports,
# protocols=game.options.protocols,
# where=["network", "nodes", router_hostname, "acl", "acl"],
# num_rules=max_acl_rules,
# )
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
# class NullObservation(AbstractObservation):
# """Null observation, returns a single 0 value for the observation space."""
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
# def __init__(self, where: Optional[List[str]] = None):
# """Initialise null observation."""
# self.default_observation: Dict = {}
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
# def observe(self, state: Dict) -> Dict:
# """Generate observation based on the current state of the simulation."""
# return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
# @property
# def space(self) -> spaces.Space:
# """Gymnasium space object describing the observation space shape."""
# return spaces.Discrete(1)
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
# @classmethod
# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
# """
# Create null observation from a config.
@classmethod
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
# The parameters are ignored, they are here to match the signature of the other observation classes.
# """
# return cls()
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
return cls()
# class ICSObservation(NullObservation):
# """ICS observation placeholder, currently not implemented so always returns a single 0."""
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
# pass
pass
'''

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
from typing import Dict, List, Optional
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""Router observation, provides status information about a router within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for RouterObservation."""
hostname: str
"""Hostname of the router, used for querying simulation state dictionary."""
ports: Optional[List[PortObservation.ConfigSchema]] = None
"""Configuration of port observations for this router."""
num_ports: Optional[int] = None
"""Number of port observations configured for this router."""
acl: Optional[ACLObservation.ConfigSchema] = None
"""Configuration of ACL observation on this router."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
def __init__(
self,
where: WhereType,
ports: List[PortObservation],
num_ports: int,
acl: ACLObservation,
) -> None:
"""
Initialize a router observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this router.
A typical location for a router might be ['network', 'nodes', <node_hostname>].
:type where: WhereType
:param ports: List of port observations representing the ports of the router.
:type ports: List[PortObservation]
:param num_ports: Number of ports for the router.
:type num_ports: int
:param acl: ACL observation representing the access control list of the router.
:type acl: ACLObservation
"""
self.where: WhereType = where
self.ports: List[PortObservation] = ports
self.acl: ACLObservation = acl
self.num_ports: int = num_ports
while len(self.ports) < num_ports:
self.ports.append(PortObservation(where=None))
while len(self.ports) > num_ports:
self.ports.pop()
msg = "Too many ports in router observation. Truncating."
_LOGGER.warning(msg)
self.default_observation = {
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
"ACL": self.acl.default_observation,
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of ports and ACL configuration of the router.
:rtype: ObsType
"""
router_state = access_from_nested_dict(state, self.where)
if router_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
obs["ACL"] = self.acl.observe(state)
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for router status.
:rtype: spaces.Space
"""
return spaces.Dict(
{"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space}
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
"""
Create a router observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the router observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this router's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed router observation instance.
:rtype: RouterObservation
"""
where = parent_where + ["nodes", config.hostname]
if config.acl is None:
config.acl = ACLObservation.ConfigSchema()
if config.acl.num_rules is None:
config.acl.num_rules = config.num_rules
if config.acl.ip_list is None:
config.acl.ip_list = config.ip_list
if config.acl.wildcard_list is None:
config.acl.wildcard_list = config.wildcard_list
if config.acl.port_list is None:
config.acl.port_list = config.port_list
if config.acl.protocol_list is None:
config.acl.protocol_list = config.protocol_list
if config.ports is None:
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)

View File

@@ -1,45 +1,43 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
"""Service observation, shows status of a service in the simulation environment."""
class ServiceObservation(AbstractObservation):
"""Observation of a service in the network."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ServiceObservation."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
service_name: str
"""Name of the service, used for querying simulation state dictionary"""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise service observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'services', <service_name>]`
:type where: Optional[List[str]]
def __init__(self, where: WhereType) -> None:
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialize a service observation instance.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param where: Where in the simulation state dictionary to find the relevant information for this service.
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
:type where: WhereType
"""
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0}
:param state: Simulation state dictionary
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing the operating status and health status of the service.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
@@ -50,114 +48,96 @@ class ServiceObservation(AbstractObservation):
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for service status.
:rtype: spaces.Space
"""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
"""
Create a service observation from a configuration schema.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:param config: Configuration schema containing the necessary information for the service observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this service's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed service observation instance.
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", config["service_name"]])
return cls(where=parent_where + ["services", config.service_name])
class ApplicationObservation(AbstractObservation):
"""Observation of an application in the network."""
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
"""Application observation, shows the status of an application within the simulation environment."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
"Default observation is what should be returned when the application doesn't exist."
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ApplicationObservation."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise application observation.
application_name: str
"""Name of the application, used for querying simulation state dictionary"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'applications', <application_name>]`
:type where: Optional[List[str]]
def __init__(self, where: WhereType) -> None:
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialise an application observation instance.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param where: Where in the simulation state dictionary to find the relevant information for this application.
A typical location for an application might be
['network', 'nodes', <node_hostname>, 'applications', <application_name>].
:type where: WhereType
"""
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
:param state: Simulation state dictionary
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Obs containing the operating status, health status, and number of executions of the application.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
app_state = access_from_nested_dict(state, self.where)
if app_state is NOT_PRESENT_IN_STATE:
application_state = access_from_nested_dict(state, self.where)
if application_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {
"operating_status": app_state["operating_state"],
"health_status": app_state["health_state_visible"],
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"num_executions": application_state["num_executions"],
}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for application status.
:rtype: spaces.Space
"""
return spaces.Dict(
{
"operating_status": spaces.Discrete(7),
"health_status": spaces.Discrete(6),
"health_status": spaces.Discrete(5),
"num_executions": spaces.Discrete(4),
}
)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ApplicationObservation":
"""Create application observation from a config.
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
"""
Create an application observation from a configuration schema.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:param config: Configuration schema containing the necessary information for the application observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this application's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed application observation instance.
:rtype: ApplicationObservation
"""
return cls(where=parent_where + ["services", config["application_name"]])
@classmethod
def _categorise_num_executions(cls, num_executions: int) -> int:
"""
Categorise the number of executions of an application.
Helps classify the number of application executions into different categories.
Current categories:
- 0: Application is never executed
- 1: Application is executed a low number of times (1-5)
- 2: Application is executed often (6-10)
- 3: Application is executed a high number of times (more than 10)
:param: num_executions: Number of times the application is executed
"""
if num_executions > 10:
return 3
elif num_executions > 5:
return 2
elif num_executions > 0:
return 1
return 0
return cls(where=parent_where + ["applications", config.application_name])