#2417 Move classes to correct files
This commit is contained in:
187
src/primaite/game/agent/observations/acl_observation.py
Normal file
187
src/primaite/game/agent/observations/acl_observation.py
Normal file
@@ -0,0 +1,187 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"""ACL observation, provides information about access control lists within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ACLObservation."""
|
||||
|
||||
ip_list: Optional[List[IPv4Address]] = None
|
||||
"""List of IP addresses."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of wildcard strings."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of port numbers."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocol names."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of ACL rules."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
num_rules: int,
|
||||
ip_list: List[IPv4Address],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
Initialize an ACL observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
|
||||
:type where: WhereType
|
||||
:param num_rules: Number of ACL rules.
|
||||
:type num_rules: int
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[IPv4Address]
|
||||
:param wildcard_list: List of wildcard strings.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param protocol_list: List of protocol names.
|
||||
:type protocol_list: List[str]
|
||||
"""
|
||||
self.where = where
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing ACL rules.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = self.ip_to_id.get(src_ip, 1)
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = self.ip_to_id.get(dst_ip, 1)
|
||||
src_wildcard = rule_state["source_wildcard_id"]
|
||||
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
|
||||
dst_wildcard = rule_state["dest_wildcard_id"]
|
||||
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
|
||||
src_port = rule_state["source_port_id"]
|
||||
src_port_id = self.port_to_id.get(src_port, 1)
|
||||
dst_port = rule_state["dest_port_id"]
|
||||
dst_port_id = self.port_to_id.get(dst_port, 1)
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
"source_port_id": src_port_id,
|
||||
"dest_ip_id": dst_node_ip,
|
||||
"dest_wildcard_id": dst_wildcard_id,
|
||||
"dest_port_id": dst_port_id,
|
||||
"protocol_id": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for ACL rules.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
|
||||
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
|
||||
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation:
|
||||
"""
|
||||
Create an ACL observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the ACL observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this ACL's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed ACL observation instance.
|
||||
:rtype: ACLObservation
|
||||
"""
|
||||
return cls(
|
||||
where=parent_where + ["acl", "acl"],
|
||||
num_rules=config.num_rules,
|
||||
ip_list=config.ip_list,
|
||||
wildcard_list=config.wildcard_list,
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
)
|
||||
@@ -1,138 +0,0 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.observations import (
|
||||
AbstractObservation,
|
||||
AclObservation,
|
||||
ICSObservation,
|
||||
LinkObservation,
|
||||
NullObservation,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
|
||||
"""Container for all observations used by the blue agent in UC2.
|
||||
|
||||
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
|
||||
for the purpose of compiling several observation components.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialise UC2 blue observation.
|
||||
|
||||
:param nodes: List of node observations
|
||||
:type nodes: List[NodeObservation]
|
||||
:param links: List of link observations
|
||||
:type links: List[LinkObservation]
|
||||
:param acl: The Access Control List observation
|
||||
:type acl: AclObservation
|
||||
:param ics: The ICS observation
|
||||
:type ics: ICSObservation
|
||||
:param where: Where in the simulation state dict to find information. Not used in this particular observation
|
||||
because it only compiles other observations and doesn't contribute any new information, defaults to None
|
||||
:type where: Optional[List[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: AclObservation = acl
|
||||
self.ics: ICSObservation = ics
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
|
||||
"ACL": self.acl.default_observation,
|
||||
"ICS": self.ics.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
obs["ICS"] = self.ics.observe(state)
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
links, ACL and ICS observations.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed UC2 blue observation
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
|
||||
num_services_per_node = config["num_services_per_node"]
|
||||
num_folders_per_node = config["num_folders_per_node"]
|
||||
num_files_per_folder = config["num_files_per_folder"]
|
||||
num_nics_per_node = config["num_nics_per_node"]
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
game=game,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
for n in node_configs
|
||||
]
|
||||
|
||||
link_configs = config["links"]
|
||||
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
|
||||
|
||||
acl_config = config["acl"]
|
||||
acl = AclObservation.from_config(config=acl_config, game=game)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, game=game)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
@@ -1,107 +1,130 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
"""File observation, provides status information about a file within the simulation environment."""
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
"""Observation of a file on a node in the network."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FileObservation."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
file_name: str
|
||||
"""Name of the file, used for querying simulation state dictionary."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to the file in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool) -> None:
|
||||
"""
|
||||
Initialise file observation.
|
||||
Initialize a file observation instance.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this file.
|
||||
A typical location for a file might be
|
||||
['network', 'nodes', <node_hostname>, 'file_system', 'folder', <folder_name>, 'files', <file_name>].
|
||||
:type where: WhereType
|
||||
:param include_num_access: Whether to include the number of accesses to the file in the observation.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
self.where: WhereType = where
|
||||
self.include_num_access: bool = include_num_access
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
self.default_observation: ObsType = {"health_status": 0}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing the health status of the file and optionally the number of accesses.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["visible_status"]}
|
||||
obs = {"health_status": file_state["visible_status"]}
|
||||
if self.include_num_access:
|
||||
obs["num_access"] = file_state["num_access"]
|
||||
# raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.")
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:return: Gymnasium space representing the observation space for file status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
space = {"health_status": spaces.Discrete(6)}
|
||||
if self.include_num_access:
|
||||
space["num_access"] = spaces.Discrete(4)
|
||||
return spaces.Dict(space)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
:type config: Dict
|
||||
:param game: _description_
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: _description_, defaults to None
|
||||
:type parent_where: _type_, optional
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation:
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config["file_name"]])
|
||||
Create a file observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the file observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this file's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed file observation instance.
|
||||
:rtype: FileObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
"""Folder observation, including files inside of the folder."""
|
||||
class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
"""Folder observation, provides status information about a folder within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FolderObservation."""
|
||||
|
||||
folder_name: str
|
||||
"""Name of the folder, used for querying simulation state dictionary."""
|
||||
files: List[FileObservation.ConfigSchema] = []
|
||||
"""List of file configurations within the folder."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of spaces for file observations in this folder."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether files in this folder should include the number of accesses in their observation."""
|
||||
|
||||
def __init__(
|
||||
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
|
||||
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
|
||||
) -> None:
|
||||
"""Initialise folder Observation, including files inside the folder.
|
||||
"""
|
||||
Initialize a folder observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param max_files: As size of the space must remain static, define max files that can be in this folder
|
||||
, defaults to 5
|
||||
:type max_files: int, optional
|
||||
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
|
||||
that even if new files are created, the existing files will always occupy the same space in the observation
|
||||
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
|
||||
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
|
||||
name, it will take the position defined in this dict. Defaults to {}
|
||||
:type file_positions: Dict[int, str], optional
|
||||
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
|
||||
:type where: WhereType
|
||||
:param files: List of file observation instances within the folder.
|
||||
:type files: Iterable[FileObservation]
|
||||
:param num_files: Number of files expected in the folder.
|
||||
:type num_files: int
|
||||
:param include_num_access: Whether to include the number of accesses to files in the observation.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.where: WhereType = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files_per_folder:
|
||||
self.files.append(FileObservation())
|
||||
while len(self.files) > num_files_per_folder:
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
|
||||
while len(self.files) > num_files:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warning(msg)
|
||||
@@ -111,16 +134,15 @@ class FolderObservation(AbstractObservation):
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the health status of the folder and status of files within the folder.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
@@ -136,9 +158,10 @@ class FolderObservation(AbstractObservation):
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:return: Gymnasium space representing the observation space for folder status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
@@ -149,29 +172,23 @@ class FolderObservation(AbstractObservation):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation:
|
||||
"""
|
||||
Create a folder observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
|
||||
folder and the files inside of it.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param config: Configuration schema containing the necessary information for the folder observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_hostname>,'file_system']
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed folder observation
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed folder observation instance.
|
||||
:rtype: FolderObservation
|
||||
"""
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
where = parent_where + ["folders", config.folder_name]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
|
||||
# pass down shared/common config items
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
|
||||
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files]
|
||||
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
|
||||
|
||||
213
src/primaite/game/agent/observations/firewall_observation.py
Normal file
213
src/primaite/game/agent/observations/firewall_observation.py
Normal file
@@ -0,0 +1,213 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""Firewall observation, provides status information about a firewall within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FirewallObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the firewall node, used for querying simulation state dictionary."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a firewall observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
|
||||
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
|
||||
:type where: WhereType
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[str]
|
||||
:param wildcard_list: List of wildcard rules.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param protocol_list: List of protocol types.
|
||||
:type protocol_list: List[str]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
:type num_rules: int
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
# TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "internal", "inbound"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.internal_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "internal", "outbound"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "dmz", "inbound"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "dmz", "outbound"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_inbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "external", "inbound"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_outbound_acl = ACLObservation(
|
||||
where=self.where + ["acl", "external", "outbound"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for firewall status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation:
|
||||
"""
|
||||
Create a firewall observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the firewall observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this firewall's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <firewall_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed firewall observation instance.
|
||||
:rtype: FirewallObservation
|
||||
"""
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
return cls(
|
||||
where=where,
|
||||
ip_list=config.ip_list,
|
||||
wildcard_list=config.wildcard_list,
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
num_rules=config.num_rules,
|
||||
)
|
||||
229
src/primaite/game/agent/observations/host_observations.py
Normal file
229
src/primaite/game/agent/observations/host_observations.py
Normal file
@@ -0,0 +1,229 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""Host observation, provides status information about a host within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for HostObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the host, used for querying simulation state dictionary."""
|
||||
services: List[ServiceObservation.ConfigSchema] = []
|
||||
"""List of services to observe on the host."""
|
||||
applications: List[ApplicationObservation.ConfigSchema] = []
|
||||
"""List of applications to observe on the host."""
|
||||
folders: List[FolderObservation.ConfigSchema] = []
|
||||
"""List of folders to observe on the host."""
|
||||
network_interfaces: List[NICObservation.ConfigSchema] = []
|
||||
"""List of network interfaces to observe on the host."""
|
||||
num_services: Optional[int] = None
|
||||
"""Number of spaces for service observations on this host."""
|
||||
num_applications: Optional[int] = None
|
||||
"""Number of spaces for application observations on this host."""
|
||||
num_folders: Optional[int] = None
|
||||
"""Number of spaces for folder observations on this host."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of spaces for file observations on this host."""
|
||||
num_nics: Optional[int] = None
|
||||
"""Number of spaces for network interface observations on this host."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Whether network interface observations should include number of malicious network events."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to files observations on this host."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
services: List[ServiceObservation],
|
||||
applications: List[ApplicationObservation],
|
||||
folders: List[FolderObservation],
|
||||
network_interfaces: List[NICObservation],
|
||||
num_services: int,
|
||||
num_applications: int,
|
||||
num_folders: int,
|
||||
num_files: int,
|
||||
num_nics: int,
|
||||
include_nmne: bool,
|
||||
include_num_access: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a host observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this host.
|
||||
A typical location for a host might be ['network', 'nodes', <hostname>].
|
||||
:type where: WhereType
|
||||
:param services: List of service observations on the host.
|
||||
:type services: List[ServiceObservation]
|
||||
:param applications: List of application observations on the host.
|
||||
:type applications: List[ApplicationObservation]
|
||||
:param folders: List of folder observations on the host.
|
||||
:type folders: List[FolderObservation]
|
||||
:param network_interfaces: List of network interface observations on the host.
|
||||
:type network_interfaces: List[NICObservation]
|
||||
:param num_services: Number of services to observe.
|
||||
:type num_services: int
|
||||
:param num_applications: Number of applications to observe.
|
||||
:type num_applications: int
|
||||
:param num_folders: Number of folders to observe.
|
||||
:type num_folders: int
|
||||
:param num_files: Number of files.
|
||||
:type num_files: int
|
||||
:param num_nics: Number of network interfaces.
|
||||
:type num_nics: int
|
||||
:param include_nmne: Flag to include network metrics and errors.
|
||||
:type include_nmne: bool
|
||||
:param include_num_access: Flag to include the number of accesses to files.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
# Ensure lists have lengths equal to specified counts by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services:
|
||||
self.services.append(ServiceObservation(where=None))
|
||||
while len(self.services) > num_services:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.applications: List[ApplicationObservation] = applications
|
||||
while len(self.applications) < num_applications:
|
||||
self.applications.append(ApplicationObservation(where=None))
|
||||
while len(self.applications) > num_applications:
|
||||
truncated_application = self.applications.pop()
|
||||
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
while len(self.folders) < num_folders:
|
||||
self.folders.append(
|
||||
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
|
||||
)
|
||||
while len(self.folders) > num_folders:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.network_interfaces: List[NICObservation] = network_interfaces
|
||||
while len(self.network_interfaces) < num_nics:
|
||||
self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne))
|
||||
while len(self.network_interfaces) > num_nics:
|
||||
truncated_nic = self.network_interfaces.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation: ObsType = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
"num_file_creations": 0,
|
||||
"num_file_deletions": 0,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status information about the host.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
}
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for host status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
"num_file_creations": spaces.Discrete(4),
|
||||
"num_file_deletions": spaces.Discrete(4),
|
||||
}
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation:
|
||||
"""
|
||||
Create a host observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the host observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this host.
|
||||
A typical location might be ['network', 'nodes', <hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed host observation instance.
|
||||
:rtype: HostObservation
|
||||
"""
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", config.hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
|
||||
# Pass down shared/common config items
|
||||
for folder_config in config.folders:
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
|
||||
applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
|
||||
folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]
|
||||
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
applications=applications,
|
||||
folders=folders,
|
||||
network_interfaces=nics,
|
||||
num_services=config.num_services,
|
||||
num_applications=config.num_applications,
|
||||
num_folders=config.num_folders,
|
||||
num_files=config.num_files,
|
||||
num_nics=config.num_nics,
|
||||
include_nmne=config.include_nmne,
|
||||
include_num_access=config.include_num_access,
|
||||
)
|
||||
@@ -1,188 +1,157 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
|
||||
low_nmne_threshold: int = 0
|
||||
"""The minimum number of malicious network events to be considered low."""
|
||||
med_nmne_threshold: int = 5
|
||||
"""The minimum number of malicious network events to be considered medium."""
|
||||
high_nmne_threshold: int = 10
|
||||
"""The minimum number of malicious network events to be considered high."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NICObservation."""
|
||||
|
||||
global CAPTURE_NMNE
|
||||
nic_num: int
|
||||
"""Number of the network interface."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Whether to include number of malicious network events (NMNE) in the observation."""
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
data = {"nic_status": 0}
|
||||
if CAPTURE_NMNE:
|
||||
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
|
||||
return data
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
low_nmne_threshold: Optional[int] = 0,
|
||||
med_nmne_threshold: Optional[int] = 5,
|
||||
high_nmne_threshold: Optional[int] = 10,
|
||||
) -> None:
|
||||
"""Initialise NIC observation.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_hostname>,'NICs',<nic_number>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
def __init__(self, where: WhereType, include_nmne: bool) -> None:
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialize a network interface observation instance.
|
||||
|
||||
global CAPTURE_NMNE
|
||||
if CAPTURE_NMNE:
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
|
||||
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
|
||||
self._validate_nmne_categories(
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
|
||||
def _validate_nmne_categories(
|
||||
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
|
||||
):
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
|
||||
A typical location for a network interface might be
|
||||
['network', 'nodes', <node_hostname>, 'NICs', <nic_num>].
|
||||
:type where: WhereType
|
||||
:param include_nmne: Flag to determine whether to include NMNE information in the observation.
|
||||
:type include_nmne: bool
|
||||
"""
|
||||
Validates the nmne threshold config.
|
||||
self.where = where
|
||||
self.include_nmne: bool = include_nmne
|
||||
|
||||
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
|
||||
self.default_observation: ObsType = {"nic_status": 0}
|
||||
if self.include_nmne:
|
||||
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
|
||||
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
|
||||
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
|
||||
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
if high_nmne_threshold <= med_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
|
||||
f"than medium nmne count ({med_nmne_threshold})"
|
||||
)
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
if med_nmne_threshold <= low_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
|
||||
f"than low nmne count ({low_nmne_threshold})"
|
||||
)
|
||||
|
||||
self.high_nmne_threshold = high_nmne_threshold
|
||||
self.med_nmne_threshold = med_nmne_threshold
|
||||
self.low_nmne_threshold = low_nmne_threshold
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
||||
|
||||
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
|
||||
|
||||
Bins are defined as follows:
|
||||
- 0: No MNEs detected (0 events).
|
||||
- 1: Low number of MNEs (default 1-5 events).
|
||||
- 2: Moderate number of MNEs (default 6-10 events).
|
||||
- 3: High number of MNEs (default more than 10 events).
|
||||
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > self.high_nmne_threshold:
|
||||
return 3
|
||||
elif nmne_count > self.med_nmne_threshold:
|
||||
return 2
|
||||
elif nmne_count > self.low_nmne_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing the status of the network interface and optionally NMNE information.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if CAPTURE_NMNE:
|
||||
obs_dict.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs_dict
|
||||
|
||||
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for network interface status and NMNE information.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
if CAPTURE_NMNE:
|
||||
if self.include_nmne:
|
||||
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation:
|
||||
"""
|
||||
low_nmne_threshold = None
|
||||
med_nmne_threshold = None
|
||||
high_nmne_threshold = None
|
||||
Create a network interface observation from a configuration schema.
|
||||
|
||||
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
|
||||
threshold = game.options.thresholds["nmne"]
|
||||
:param config: Configuration schema containing the necessary information for the network interface observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed network interface observation instance.
|
||||
:rtype: NICObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne)
|
||||
|
||||
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
|
||||
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
|
||||
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
|
||||
|
||||
return cls(
|
||||
where=parent_where + ["NICs", config["nic_num"]],
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
"""Port observation, provides status information about a network port within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for PortObservation."""
|
||||
|
||||
port_id: int
|
||||
"""Identifier of the port, used for querying simulation state dictionary."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialize a port observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this port.
|
||||
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"operating_status": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the operating status of the port.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
port_state = access_from_nested_dict(state, self.where)
|
||||
if port_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"operating_status": 1 if port_state["enabled"] else 2}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for port status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation:
|
||||
"""
|
||||
Create a port observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the port observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this port's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed port observation instance.
|
||||
:rtype: PortObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config.port_id])
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,24 +1,23 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type
|
||||
from typing import Any, Dict, Iterable, Type
|
||||
|
||||
from gymnasium import spaces
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
WhereType = Iterable[str | int] | None
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
"""Abstract class for an observation space component."""
|
||||
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
"""Config schema for observations."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
@@ -61,269 +60,271 @@ class AbstractObservation(ABC):
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
ObservationType = cls._registry[cfg['type']]
|
||||
ObservationType = cls._registry[cfg["type"]]
|
||||
return ObservationType.from_config(cfg=cfg)
|
||||
|
||||
|
||||
# class LinkObservation(AbstractObservation):
|
||||
# """Observation of a link in the network."""
|
||||
'''
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
# "Default observation is what should be returned when the link doesn't exist."
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
# def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
# """Initialise link observation.
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
|
||||
# :param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
# Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
# zeroes.
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
# A typical location for a service looks like this:
|
||||
# `['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
# :type where: Optional[List[str]]
|
||||
# """
|
||||
# super().__init__()
|
||||
# self.where: Optional[Tuple[str]] = where
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
# def observe(self, state: Dict) -> Dict:
|
||||
# """Generate observation based on the current state of the simulation.
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
# :param state: Simulation state dictionary
|
||||
# :type state: Dict
|
||||
# :return: Observation
|
||||
# :rtype: Dict
|
||||
# """
|
||||
# if self.where is None:
|
||||
# return self.default_observation
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
# link_state = access_from_nested_dict(state, self.where)
|
||||
# if link_state is NOT_PRESENT_IN_STATE:
|
||||
# return self.default_observation
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# bandwidth = link_state["bandwidth"]
|
||||
# load = link_state["current_load"]
|
||||
# if load == 0:
|
||||
# utilisation_category = 0
|
||||
# else:
|
||||
# utilisation_fraction = load / bandwidth
|
||||
# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
# utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# # TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
# @property
|
||||
# def space(self) -> spaces.Space:
|
||||
# """Gymnasium space object describing the observation space shape.
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
# :return: Gymnasium space
|
||||
# :rtype: spaces.Space
|
||||
# """
|
||||
# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
# @classmethod
|
||||
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
# """Create link observation from a config.
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
# :param config: Dictionary containing the configuration for this link observation.
|
||||
# :type config: Dict
|
||||
# :param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
# :type game: PrimaiteGame
|
||||
# :return: Constructed link observation
|
||||
# :rtype: LinkObservation
|
||||
# """
|
||||
# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
# class AclObservation(AbstractObservation):
|
||||
# """Observation of an Access Control List (ACL) in the network."""
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# # TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# # definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# # if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# # this needs adding, but not for the MVP.
|
||||
# def __init__(
|
||||
# self,
|
||||
# node_ip_to_id: Dict[str, int],
|
||||
# ports: List[int],
|
||||
# protocols: List[str],
|
||||
# where: Optional[Tuple[str]] = None,
|
||||
# num_rules: int = 10,
|
||||
# ) -> None:
|
||||
# """Initialise ACL observation.
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: List[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
|
||||
# :param node_ip_to_id: Mapping between IP address and ID.
|
||||
# :type node_ip_to_id: Dict[str, int]
|
||||
# :param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
# :type ports: List[int]
|
||||
# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
# :type protocols: list[str]
|
||||
# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
# example may look like this:
|
||||
# ['network','nodes',<router_hostname>,'acl','acl']
|
||||
# :type where: Optional[Tuple[str]], optional
|
||||
# :param num_rules: , defaults to 10
|
||||
# :type num_rules: int, optional
|
||||
# """
|
||||
# super().__init__()
|
||||
# self.where: Optional[Tuple[str]] = where
|
||||
# self.num_rules: int = num_rules
|
||||
# self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
# "List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
# "List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
# "List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
# self.default_observation: Dict = {
|
||||
# i
|
||||
# + 1: {
|
||||
# "position": i,
|
||||
# "permission": 0,
|
||||
# "source_node_id": 0,
|
||||
# "source_port": 0,
|
||||
# "dest_node_id": 0,
|
||||
# "dest_port": 0,
|
||||
# "protocol": 0,
|
||||
# }
|
||||
# for i in range(self.num_rules)
|
||||
# }
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
# def observe(self, state: Dict) -> Dict:
|
||||
# """Generate observation based on the current state of the simulation.
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
# :param state: Simulation state dictionary
|
||||
# :type state: Dict
|
||||
# :return: Observation
|
||||
# :rtype: Dict
|
||||
# """
|
||||
# if self.where is None:
|
||||
# return self.default_observation
|
||||
# acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
# if acl_state is NOT_PRESENT_IN_STATE:
|
||||
# return self.default_observation
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# # TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
# obs = {}
|
||||
# acl_items = dict(acl_state.items())
|
||||
# i = 1 # don't show rule 0 for compatibility reasons.
|
||||
# while i < self.num_rules + 1:
|
||||
# rule_state = acl_items[i]
|
||||
# if rule_state is None:
|
||||
# obs[i] = {
|
||||
# "position": i - 1,
|
||||
# "permission": 0,
|
||||
# "source_node_id": 0,
|
||||
# "source_port": 0,
|
||||
# "dest_node_id": 0,
|
||||
# "dest_port": 0,
|
||||
# "protocol": 0,
|
||||
# }
|
||||
# else:
|
||||
# src_ip = rule_state["src_ip_address"]
|
||||
# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
# dst_ip = rule_state["dst_ip_address"]
|
||||
# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
# src_port = rule_state["src_port"]
|
||||
# src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
# dst_port = rule_state["dst_port"]
|
||||
# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
# protocol = rule_state["protocol"]
|
||||
# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
# obs[i] = {
|
||||
# "position": i - 1,
|
||||
# "permission": rule_state["action"],
|
||||
# "source_node_id": src_node_id,
|
||||
# "source_port": src_port_id,
|
||||
# "dest_node_id": dst_node_ip,
|
||||
# "dest_port": dst_port_id,
|
||||
# "protocol": protocol_id,
|
||||
# }
|
||||
# i += 1
|
||||
# return obs
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
# @property
|
||||
# def space(self) -> spaces.Space:
|
||||
# """Gymnasium space object describing the observation space shape.
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
# :return: Gymnasium space
|
||||
# :rtype: spaces.Space
|
||||
# """
|
||||
# return spaces.Dict(
|
||||
# {
|
||||
# i
|
||||
# + 1: spaces.Dict(
|
||||
# {
|
||||
# "position": spaces.Discrete(self.num_rules),
|
||||
# "permission": spaces.Discrete(3),
|
||||
# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
# "source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
# "dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
# }
|
||||
# )
|
||||
# for i in range(self.num_rules)
|
||||
# }
|
||||
# )
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
# @classmethod
|
||||
# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
# """Generate ACL observation from a config.
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
# :param config: Dictionary containing the configuration for this ACL observation.
|
||||
# :type config: Dict
|
||||
# :param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
# :type game: PrimaiteGame
|
||||
# :return: Observation object
|
||||
# :rtype: AclObservation
|
||||
# """
|
||||
# max_acl_rules = config["options"]["max_acl_rules"]
|
||||
# node_ip_to_idx = {}
|
||||
# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
# node_ref = ip_map_config["node_hostname"]
|
||||
# nic_num = ip_map_config["nic_num"]
|
||||
# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
# nic_obj = node_obj.network_interface[nic_num]
|
||||
# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_hostname"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.network_interface[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
# router_hostname = config["router_hostname"]
|
||||
# return cls(
|
||||
# node_ip_to_id=node_ip_to_idx,
|
||||
# ports=game.options.ports,
|
||||
# protocols=game.options.protocols,
|
||||
# where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
# num_rules=max_acl_rules,
|
||||
# )
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
# class NullObservation(AbstractObservation):
|
||||
# """Null observation, returns a single 0 value for the observation space."""
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
# def __init__(self, where: Optional[List[str]] = None):
|
||||
# """Initialise null observation."""
|
||||
# self.default_observation: Dict = {}
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
# def observe(self, state: Dict) -> Dict:
|
||||
# """Generate observation based on the current state of the simulation."""
|
||||
# return 0
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
|
||||
# @property
|
||||
# def space(self) -> spaces.Space:
|
||||
# """Gymnasium space object describing the observation space shape."""
|
||||
# return spaces.Discrete(1)
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
# @classmethod
|
||||
# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
# """
|
||||
# Create null observation from a config.
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
# The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
# """
|
||||
# return cls()
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
# class ICSObservation(NullObservation):
|
||||
# """ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
# pass
|
||||
pass
|
||||
'''
|
||||
|
||||
142
src/primaite/game/agent/observations/router_observation.py
Normal file
142
src/primaite/game/agent/observations/router_observation.py
Normal file
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
"""Router observation, provides status information about a router within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for RouterObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the router, used for querying simulation state dictionary."""
|
||||
ports: Optional[List[PortObservation.ConfigSchema]] = None
|
||||
"""Configuration of port observations for this router."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of port observations configured for this router."""
|
||||
acl: Optional[ACLObservation.ConfigSchema] = None
|
||||
"""Configuration of ACL observation on this router."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
ports: List[PortObservation],
|
||||
num_ports: int,
|
||||
acl: ACLObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a router observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this router.
|
||||
A typical location for a router might be ['network', 'nodes', <node_hostname>].
|
||||
:type where: WhereType
|
||||
:param ports: List of port observations representing the ports of the router.
|
||||
:type ports: List[PortObservation]
|
||||
:param num_ports: Number of ports for the router.
|
||||
:type num_ports: int
|
||||
:param acl: ACL observation representing the access control list of the router.
|
||||
:type acl: ACLObservation
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.ports: List[PortObservation] = ports
|
||||
self.acl: ACLObservation = acl
|
||||
self.num_ports: int = num_ports
|
||||
|
||||
while len(self.ports) < num_ports:
|
||||
self.ports.append(PortObservation(where=None))
|
||||
while len(self.ports) > num_ports:
|
||||
self.ports.pop()
|
||||
msg = "Too many ports in router observation. Truncating."
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
|
||||
"ACL": self.acl.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of ports and ACL configuration of the router.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
router_state = access_from_nested_dict(state, self.where)
|
||||
if router_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for router status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation:
|
||||
"""
|
||||
Create a router observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the router observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this router's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed router observation instance.
|
||||
:rtype: RouterObservation
|
||||
"""
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
|
||||
if config.acl is None:
|
||||
config.acl = ACLObservation.ConfigSchema()
|
||||
if config.acl.num_rules is None:
|
||||
config.acl.num_rules = config.num_rules
|
||||
if config.acl.ip_list is None:
|
||||
config.acl.ip_list = config.ip_list
|
||||
if config.acl.wildcard_list is None:
|
||||
config.acl.wildcard_list = config.wildcard_list
|
||||
if config.acl.port_list is None:
|
||||
config.acl.port_list = config.port_list
|
||||
if config.acl.protocol_list is None:
|
||||
config.acl.protocol_list = config.protocol_list
|
||||
|
||||
if config.ports is None:
|
||||
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
|
||||
|
||||
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
@@ -1,45 +1,43 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
"""Service observation, shows status of a service in the simulation environment."""
|
||||
|
||||
class ServiceObservation(AbstractObservation):
|
||||
"""Observation of a service in the network."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ServiceObservation."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
service_name: str
|
||||
"""Name of the service, used for querying simulation state dictionary"""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise service observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'services', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialize a service observation instance.
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this service.
|
||||
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0}
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing the operating status and health status of the service.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
@@ -50,114 +48,96 @@ class ServiceObservation(AbstractObservation):
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for service status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation:
|
||||
"""
|
||||
Create a service observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:param config: Configuration schema containing the necessary information for the service observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this service's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed service observation instance.
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["service_name"]])
|
||||
return cls(where=parent_where + ["services", config.service_name])
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation):
|
||||
"""Observation of an application in the network."""
|
||||
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
"""Application observation, shows the status of an application within the simulation environment."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
"Default observation is what should be returned when the application doesn't exist."
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ApplicationObservation."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise application observation.
|
||||
application_name: str
|
||||
"""Name of the application, used for querying simulation state dictionary"""
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'applications', <application_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialise an application observation instance.
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this application.
|
||||
A typical location for an application might be
|
||||
['network', 'nodes', <node_hostname>, 'applications', <application_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Obs containing the operating status, health status, and number of executions of the application.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
app_state = access_from_nested_dict(state, self.where)
|
||||
if app_state is NOT_PRESENT_IN_STATE:
|
||||
application_state = access_from_nested_dict(state, self.where)
|
||||
if application_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": app_state["operating_state"],
|
||||
"health_status": app_state["health_state_visible"],
|
||||
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"num_executions": application_state["num_executions"],
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for application status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"operating_status": spaces.Discrete(7),
|
||||
"health_status": spaces.Discrete(6),
|
||||
"health_status": spaces.Discrete(5),
|
||||
"num_executions": spaces.Discrete(4),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ApplicationObservation":
|
||||
"""Create application observation from a config.
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation:
|
||||
"""
|
||||
Create an application observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:param config: Configuration schema containing the necessary information for the application observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this application's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed application observation instance.
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["application_name"]])
|
||||
|
||||
@classmethod
|
||||
def _categorise_num_executions(cls, num_executions: int) -> int:
|
||||
"""
|
||||
Categorise the number of executions of an application.
|
||||
|
||||
Helps classify the number of application executions into different categories.
|
||||
|
||||
Current categories:
|
||||
- 0: Application is never executed
|
||||
- 1: Application is executed a low number of times (1-5)
|
||||
- 2: Application is executed often (6-10)
|
||||
- 3: Application is executed a high number of times (more than 10)
|
||||
|
||||
:param: num_executions: Number of times the application is executed
|
||||
"""
|
||||
if num_executions > 10:
|
||||
return 3
|
||||
elif num_executions > 5:
|
||||
return 2
|
||||
elif num_executions > 0:
|
||||
return 1
|
||||
return 0
|
||||
return cls(where=parent_where + ["applications", config.application_name])
|
||||
|
||||
Reference in New Issue
Block a user