diff --git a/src/primaite/game/agent/observations/acl_observation.py b/src/primaite/game/agent/observations/acl_observation.py new file mode 100644 index 00000000..2d29223d --- /dev/null +++ b/src/primaite/game/agent/observations/acl_observation.py @@ -0,0 +1,187 @@ +from __future__ import annotations + +from ipaddress import IPv4Address +from typing import Dict, List, Optional + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + + +class ACLObservation(AbstractObservation, identifier="ACL"): + """ACL observation, provides information about access control lists within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ACLObservation.""" + + ip_list: Optional[List[IPv4Address]] = None + """List of IP addresses.""" + wildcard_list: Optional[List[str]] = None + """List of wildcard strings.""" + port_list: Optional[List[int]] = None + """List of port numbers.""" + protocol_list: Optional[List[str]] = None + """List of protocol names.""" + num_rules: Optional[int] = None + """Number of ACL rules.""" + + def __init__( + self, + where: WhereType, + num_rules: int, + ip_list: List[IPv4Address], + wildcard_list: List[str], + port_list: List[int], + protocol_list: List[str], + ) -> None: + """ + Initialize an ACL observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this ACL. + :type where: WhereType + :param num_rules: Number of ACL rules. + :type num_rules: int + :param ip_list: List of IP addresses. + :type ip_list: List[IPv4Address] + :param wildcard_list: List of wildcard strings. + :type wildcard_list: List[str] + :param port_list: List of port numbers. 
+        :type port_list: List[int]
+        :param protocol_list: List of protocol names.
+        :type protocol_list: List[str]
+        """
+        self.where = where
+        self.num_rules: int = num_rules
+        self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}  # value -> id; 0 and 1 are reserved
+        self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
+        self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
+        self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
+        self.default_observation: Dict = {
+            i
+            + 1: {
+                "position": i,
+                "permission": 0,
+                "source_ip_id": 0,
+                "source_wildcard_id": 0,
+                "source_port_id": 0,
+                "dest_ip_id": 0,
+                "dest_wildcard_id": 0,
+                "dest_port_id": 0,
+                "protocol_id": 0,
+            }
+            for i in range(self.num_rules)
+        }
+
+    def observe(self, state: Dict) -> ObsType:
+        """
+        Generate observation based on the current state of the simulation.
+
+        :param state: Simulation state dictionary.
+        :type state: Dict
+        :return: Observation containing ACL rules.
+        :rtype: ObsType
+        """
+        acl_state: Dict = access_from_nested_dict(state, self.where)
+        if acl_state is NOT_PRESENT_IN_STATE:
+            return self.default_observation
+        obs = {}
+        acl_items = dict(acl_state.items())
+        i = 1  # don't show rule 0 for compatibility reasons.
+        while i < self.num_rules + 1:
+            rule_state = acl_items[i]
+            if rule_state is None:  # None entry -> all-zero row, position still reported
+                obs[i] = {
+                    "position": i - 1,
+                    "permission": 0,
+                    "source_ip_id": 0,
+                    "source_wildcard_id": 0,
+                    "source_port_id": 0,
+                    "dest_ip_id": 0,
+                    "dest_wildcard_id": 0,
+                    "dest_port_id": 0,
+                    "protocol_id": 0,
+                }
+            else:
+                src_ip = rule_state["src_ip_address"]
+                src_node_id = self.ip_to_id.get(src_ip, 1)  # values missing from the table map to 1
+                dst_ip = rule_state["dst_ip_address"]
+                dst_node_ip = self.ip_to_id.get(dst_ip, 1)  # NOTE(review): holds an id despite the "_ip" suffix
+                src_wildcard = rule_state["source_wildcard_id"]  # NOTE(review): verify these "*_id" state keys
+                src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
+                dst_wildcard = rule_state["dest_wildcard_id"]
+                dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
+                src_port = rule_state["source_port_id"]
+                src_port_id = self.port_to_id.get(src_port, 1)
+                dst_port = rule_state["dest_port_id"]
+                dst_port_id = self.port_to_id.get(dst_port, 1)
+                protocol = rule_state["protocol"]
+                protocol_id = self.protocol_to_id.get(protocol, 1)
+                obs[i] = {
+                    "position": i - 1,
+                    "permission": rule_state["action"],
+                    "source_ip_id": src_node_id,
+                    "source_wildcard_id": src_wildcard_id,
+                    "source_port_id": src_port_id,
+                    "dest_ip_id": dst_node_ip,
+                    "dest_wildcard_id": dst_wildcard_id,
+                    "dest_port_id": dst_port_id,
+                    "protocol_id": protocol_id,
+                }
+            i += 1
+        return obs
+
+    @property
+    def space(self) -> spaces.Space:
+        """
+        Gymnasium space object describing the observation space shape.
+
+        :return: Gymnasium space representing the observation space for ACL rules.
+ :rtype: spaces.Space + """ + return spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), + "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), + "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), + "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), + "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), + } + ) + for i in range(self.num_rules) + } + ) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: + """ + Create an ACL observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the ACL observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this ACL's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed ACL observation instance. 
+ :rtype: ACLObservation + """ + return cls( + where=parent_where + ["acl", "acl"], + num_rules=config.num_rules, + ip_list=config.ip_list, + wildcard_list=config.wildcard_list, + port_list=config.port_list, + protocol_list=config.protocol_list, + ) diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py deleted file mode 100644 index 10370660..00000000 --- a/src/primaite/game/agent/observations/agent_observations.py +++ /dev/null @@ -1,138 +0,0 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING - -from gymnasium import spaces - -from primaite.game.agent.observations.node_observations import NodeObservation -from primaite.game.agent.observations.observations import ( - AbstractObservation, - AclObservation, - ICSObservation, - LinkObservation, - NullObservation, -) - -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - - -class UC2BlueObservation(AbstractObservation): - """Container for all observations used by the blue agent in UC2. - - TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler - for the purpose of compiling several observation components. - """ - - def __init__( - self, - nodes: List[NodeObservation], - links: List[LinkObservation], - acl: AclObservation, - ics: ICSObservation, - where: Optional[List[str]] = None, - ) -> None: - """Initialise UC2 blue observation. - - :param nodes: List of node observations - :type nodes: List[NodeObservation] - :param links: List of link observations - :type links: List[LinkObservation] - :param acl: The Access Control List observation - :type acl: AclObservation - :param ics: The ICS observation - :type ics: ICSObservation - :param where: Where in the simulation state dict to find information. 
Not used in this particular observation - because it only compiles other observations and doesn't contribute any new information, defaults to None - :type where: Optional[List[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.nodes: List[NodeObservation] = nodes - self.links: List[LinkObservation] = links - self.acl: AclObservation = acl - self.ics: ICSObservation = ics - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, - "ACL": self.acl.default_observation, - "ICS": self.ics.default_observation, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} - obs["ACL"] = self.acl.observe(state) - obs["ICS"] = self.ics.observe(state) - - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), - "ACL": self.acl.space, - "ICS": self.ics.space, - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": - """Create UC2 blue observation from a config. - - :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, - links, ACL and ICS observations. 
- :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed UC2 blue observation - :rtype: UC2BlueObservation - """ - node_configs = config["nodes"] - - num_services_per_node = config["num_services_per_node"] - num_folders_per_node = config["num_folders_per_node"] - num_files_per_folder = config["num_files_per_folder"] - num_nics_per_node = config["num_nics_per_node"] - nodes = [ - NodeObservation.from_config( - config=n, - game=game, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - for n in node_configs - ] - - link_configs = config["links"] - links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] - - acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, game=game) - - ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, game=game) - new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) - return new - diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py index 277bc51f..a30bfc82 100644 --- a/src/primaite/game/agent/observations/file_system_observations.py +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -1,107 +1,130 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from __future__ import annotations + +from typing import Dict, Iterable, List, Optional from gymnasium import spaces +from gymnasium.core import ObsType from primaite import getLogger -from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = 
getLogger(__name__) -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +class FileObservation(AbstractObservation, identifier="FILE"): + """File observation, provides status information about a file within the simulation environment.""" -class FileObservation(AbstractObservation): - """Observation of a file on a node in the network.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FileObservation.""" - def __init__(self, where: Optional[Tuple[str]] = None) -> None: + file_name: str + """Name of the file, used for querying simulation state dictionary.""" + include_num_access: Optional[bool] = None + """Whether to include the number of accesses to the file in the observation.""" + + def __init__(self, where: WhereType, include_num_access: bool) -> None: """ - Initialise file observation. + Initialize a file observation instance. - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',,'files',] - :type where: Optional[List[str]] + :param where: Where in the simulation state dictionary to find the relevant information for this file. + A typical location for a file might be + ['network', 'nodes', , 'file_system', 'folder', , 'files', ]. + :type where: WhereType + :param include_num_access: Whether to include the number of accesses to the file in the observation. + :type include_num_access: bool """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.default_observation: spaces.Space = {"health_status": 0} - "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." 
+ self.where: WhereType = where + self.include_num_access: bool = include_num_access - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. + self.default_observation: ObsType = {"health_status": 0} + if self.include_num_access: + self.default_observation["num_access"] = 0 - :param state: Simulation state dictionary + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. :type state: Dict - :return: Observation - :rtype: Dict + :return: Observation containing the health status of the file and optionally the number of accesses. + :rtype: ObsType """ - if self.where is None: - return self.default_observation file_state = access_from_nested_dict(state, self.where) if file_state is NOT_PRESENT_IN_STATE: return self.default_observation - return {"health_status": file_state["visible_status"]} + obs = {"health_status": file_state["visible_status"]} + if self.include_num_access: + obs["num_access"] = file_state["num_access"] + # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") + return obs @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. + """ + Gymnasium space object describing the observation space shape. - :return: Gymnasium space + :return: Gymnasium space representing the observation space for file status. :rtype: spaces.Space """ - return spaces.Dict({"health_status": spaces.Discrete(6)}) + space = {"health_status": spaces.Discrete(6)} + if self.include_num_access: + space["num_access"] = spaces.Discrete(4) + return spaces.Dict(space) @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": - """Create file observation from a config. - - :param config: Dictionary containing the configuration for this file observation. 
- :type config: Dict - :param game: _description_ - :type game: PrimaiteGame - :param parent_where: _description_, defaults to None - :type parent_where: _type_, optional - :return: _description_ - :rtype: _type_ + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation: """ - return cls(where=parent_where + ["files", config["file_name"]]) + Create a file observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the file observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this file's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed file observation instance. + :rtype: FileObservation + """ + return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access) -class FolderObservation(AbstractObservation): - """Folder observation, including files inside of the folder.""" +class FolderObservation(AbstractObservation, identifier="FOLDER"): + """Folder observation, provides status information about a folder within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FolderObservation.""" + + folder_name: str + """Name of the folder, used for querying simulation state dictionary.""" + files: List[FileObservation.ConfigSchema] = [] + """List of file configurations within the folder.""" + num_files: Optional[int] = None + """Number of spaces for file observations in this folder.""" + include_num_access: Optional[bool] = None + """Whether files in this folder should include the number of accesses in their observation.""" def __init__( - self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 + self, where: WhereType, files: 
Iterable[FileObservation], num_files: int, include_num_access: bool ) -> None: - """Initialise folder Observation, including files inside the folder. + """ + Initialize a folder observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',] - :type where: Optional[List[str]] - :param max_files: As size of the space must remain static, define max files that can be in this folder - , defaults to 5 - :type max_files: int, optional - :param file_positions: Defines the positioning within the observation space of particular files. This ensures - that even if new files are created, the existing files will always occupy the same space in the observation - space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the - observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same - name, it will take the position defined in this dict. Defaults to {} - :type file_positions: Dict[int, str], optional + A typical location for a folder might be ['network', 'nodes', , 'folders', ]. + :type where: WhereType + :param files: List of file observation instances within the folder. + :type files: Iterable[FileObservation] + :param num_files: Number of files expected in the folder. + :type num_files: int + :param include_num_access: Whether to include the number of accesses to files in the observation. 
+ :type include_num_access: bool """ - super().__init__() - - self.where: Optional[Tuple[str]] = where + self.where: WhereType = where self.files: List[FileObservation] = files - while len(self.files) < num_files_per_folder: - self.files.append(FileObservation()) - while len(self.files) > num_files_per_folder: + while len(self.files) < num_files: + self.files.append(FileObservation(where=None, include_num_access=include_num_access)) + while len(self.files) > num_files: truncated_file = self.files.pop() msg = f"Too many files in folder observation. Truncating file {truncated_file}" _LOGGER.warning(msg) @@ -111,16 +134,15 @@ class FolderObservation(AbstractObservation): "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, } - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the health status of the folder and status of files within the folder. + :rtype: ObsType """ - if self.where is None: - return self.default_observation folder_state = access_from_nested_dict(state, self.where) if folder_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -136,9 +158,10 @@ class FolderObservation(AbstractObservation): @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. + """ + Gymnasium space object describing the observation space shape. - :return: Gymnasium space + :return: Gymnasium space representing the observation space for folder status. 
:rtype: spaces.Space """ return spaces.Dict( @@ -149,29 +172,23 @@ class FolderObservation(AbstractObservation): ) @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 - ) -> "FolderObservation": - """Create folder observation from a config. Also creates child file observations. + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation: + """ + Create a folder observation from a configuration schema. - :param config: Dictionary containing the configuration for this folder observation. Includes the name of the - folder and the files inside of it. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame + :param config: Configuration schema containing the necessary information for the folder observation. + :type config: ConfigSchema :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. A typical location for a node ``where`` can be: - ['network','nodes',,'file_system'] - :type parent_where: Optional[List[str]] - :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed folder observation + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed folder observation instance. 
:rtype: FolderObservation """ - where = parent_where + ["folders", config["folder_name"]] + where = parent_where + ["folders", config.folder_name] - file_configs = config["files"] - files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] + # pass down shared/common config items + for file_config in config.files: + file_config.include_num_access = config.include_num_access - return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) + files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] + return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) diff --git a/src/primaite/game/agent/observations/firewall_observation.py b/src/primaite/game/agent/observations/firewall_observation.py new file mode 100644 index 00000000..6397d473 --- /dev/null +++ b/src/primaite/game/agent/observations/firewall_observation.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from typing import Dict, List, Optional + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.nic_observations import PortObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType + +_LOGGER = getLogger(__name__) + + +class FirewallObservation(AbstractObservation, identifier="FIREWALL"): + """Firewall observation, provides status information about a firewall within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for FirewallObservation.""" + + hostname: str + """Hostname of the firewall node, used for querying simulation state dictionary.""" + ip_list: Optional[List[str]] = None + """List of IP addresses for encoding ACLs.""" + wildcard_list: Optional[List[str]] = None + """List of IP 
wildcards for encoding ACLs.""" + port_list: Optional[List[int]] = None + """List of ports for encoding ACLs.""" + protocol_list: Optional[List[str]] = None + """List of protocols for encoding ACLs.""" + num_rules: Optional[int] = None + """Number of rules ACL rules to show.""" + + def __init__( + self, + where: WhereType, + ip_list: List[str], + wildcard_list: List[str], + port_list: List[int], + protocol_list: List[str], + num_rules: int, + ) -> None: + """ + Initialize a firewall observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this firewall. + A typical location for a firewall might be ['network', 'nodes', ]. + :type where: WhereType + :param ip_list: List of IP addresses. + :type ip_list: List[str] + :param wildcard_list: List of wildcard rules. + :type wildcard_list: List[str] + :param port_list: List of port numbers. + :type port_list: List[int] + :param protocol_list: List of protocol types. + :type protocol_list: List[str] + :param num_rules: Number of rules configured in the firewall. + :type num_rules: int + """ + self.where: WhereType = where + + self.ports: List[PortObservation] = [ + PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3) + ] + # TODO: check what the port nums are for firewall. 
+ + self.internal_inbound_acl = ACLObservation( + where=self.where + ["acl", "internal", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.internal_outbound_acl = ACLObservation( + where=self.where + ["acl", "internal", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.dmz_inbound_acl = ACLObservation( + where=self.where + ["acl", "dmz", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.dmz_outbound_acl = ACLObservation( + where=self.where + ["acl", "dmz", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.external_inbound_acl = ACLObservation( + where=self.where + ["acl", "external", "inbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + self.external_outbound_acl = ACLObservation( + where=self.where + ["acl", "external", "outbound"], + num_rules=num_rules, + ip_list=ip_list, + wildcard_list=wildcard_list, + port_list=port_list, + protocol_list=protocol_list, + ) + + self.default_observation = { + "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.default_observation, + "OUTBOUND": self.internal_outbound_acl.default_observation, + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.default_observation, + "OUTBOUND": self.dmz_outbound_acl.default_observation, + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.default_observation, + "OUTBOUND": self.external_outbound_acl.default_observation, + }, + } + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state 
of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. + :rtype: ObsType + """ + obs = { + "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, + } + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for firewall status. + :rtype: spaces.Space + """ + space = spaces.Dict( + { + "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), + "INTERNAL": spaces.Dict( + { + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + } + ), + "DMZ": spaces.Dict( + { + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + } + ), + "EXTERNAL": spaces.Dict( + { + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + } + ), + } + ) + return space + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: + """ + Create a firewall observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the firewall observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this firewall's + parent node. A typical location for a node might be ['network', 'nodes', ]. 
class HostObservation(AbstractObservation, identifier="HOST"):
    """
    Host observation, provides status information about a host within the simulation environment.

    The observation combines the host's operating status and file-system counters with nested observations for its
    services, applications, folders and network interfaces. Each nested group is padded or truncated to the number
    of slots requested at construction time.
    """

    class ConfigSchema(AbstractObservation.ConfigSchema):
        """Configuration schema for HostObservation."""

        hostname: str
        """Hostname of the host, used for querying simulation state dictionary."""
        services: List[ServiceObservation.ConfigSchema] = []
        """List of services to observe on the host."""
        applications: List[ApplicationObservation.ConfigSchema] = []
        """List of applications to observe on the host."""
        folders: List[FolderObservation.ConfigSchema] = []
        """List of folders to observe on the host."""
        network_interfaces: List[NICObservation.ConfigSchema] = []
        """List of network interfaces to observe on the host."""
        num_services: Optional[int] = None
        """Number of spaces for service observations on this host."""
        num_applications: Optional[int] = None
        """Number of spaces for application observations on this host."""
        num_folders: Optional[int] = None
        """Number of spaces for folder observations on this host."""
        num_files: Optional[int] = None
        """Number of spaces for file observations on this host."""
        num_nics: Optional[int] = None
        """Number of spaces for network interface observations on this host."""
        include_nmne: Optional[bool] = None
        """Whether network interface observations should include number of malicious network events."""
        include_num_access: Optional[bool] = None
        """Whether to include the number of accesses to files observations on this host."""

    def __init__(
        self,
        where: WhereType,
        services: List[ServiceObservation],
        applications: List[ApplicationObservation],
        folders: List[FolderObservation],
        network_interfaces: List[NICObservation],
        num_services: int,
        num_applications: int,
        num_folders: int,
        num_files: int,
        num_nics: int,
        include_nmne: bool,
        include_num_access: bool,
    ) -> None:
        """
        Initialize a host observation instance.

        :param where: Where in the simulation state dictionary to find the relevant information for this host.
            A typical location for a host might be ['network', 'nodes', <hostname>].
        :type where: WhereType
        :param services: List of service observations on the host.
        :type services: List[ServiceObservation]
        :param applications: List of application observations on the host.
        :type applications: List[ApplicationObservation]
        :param folders: List of folder observations on the host.
        :type folders: List[FolderObservation]
        :param network_interfaces: List of network interface observations on the host.
        :type network_interfaces: List[NICObservation]
        :param num_services: Number of services to observe.
        :type num_services: int
        :param num_applications: Number of applications to observe.
        :type num_applications: int
        :param num_folders: Number of folders to observe.
        :type num_folders: int
        :param num_files: Number of files per folder, used when creating placeholder folder observations.
        :type num_files: int
        :param num_nics: Number of network interfaces.
        :type num_nics: int
        :param include_nmne: Flag to include the number of malicious network events (NMNE) in placeholder
            network interface observations.
        :type include_nmne: bool
        :param include_num_access: Flag to include the number of accesses to files.
        :type include_num_access: bool
        """
        self.where: WhereType = where

        # Ensure lists have lengths equal to specified counts by truncating or padding. The helper works on a copy,
        # so the lists handed in by the caller are never modified.
        self.services: List[ServiceObservation] = self._fit_to_length(
            services,
            num_services,
            lambda: ServiceObservation(where=None),
            "Too many services in Node observation space for node. Truncating service ",
        )
        self.applications: List[ApplicationObservation] = self._fit_to_length(
            applications,
            num_applications,
            lambda: ApplicationObservation(where=None),
            "Too many applications in Node observation space for node. Truncating ",
        )
        self.folders: List[FolderObservation] = self._fit_to_length(
            folders,
            num_folders,
            lambda: FolderObservation(
                where=None, files=[], num_files=num_files, include_num_access=include_num_access
            ),
            "Too many folders in Node observation space for node. Truncating folder ",
        )
        self.network_interfaces: List[NICObservation] = self._fit_to_length(
            network_interfaces,
            num_nics,
            lambda: NICObservation(where=None, include_nmne=include_nmne),
            "Too many network_interfaces in Node observation space for node. Truncating ",
        )

        # Returned as-is by ``observe`` when the host cannot be found in the simulation state.
        self.default_observation: ObsType = {
            "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
            "APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
            "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
            "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
            "operating_status": 0,
            "num_file_creations": 0,
            "num_file_deletions": 0,
        }

    @staticmethod
    def _fit_to_length(items: List, target_len: int, make_placeholder, overflow_msg: str) -> List:
        """
        Return a copy of ``items`` padded or truncated to exactly ``target_len`` entries.

        :param items: Observations supplied by the caller; this list is not modified.
        :type items: List
        :param target_len: Required number of entries.
        :type target_len: int
        :param make_placeholder: Zero-argument callable producing a placeholder observation used for padding.
        :param overflow_msg: Warning text logged for every dropped entry; the entry's ``where`` is appended to it.
        :type overflow_msg: str
        :return: New list containing exactly ``target_len`` observations.
        :rtype: List
        """
        fitted = list(items)
        while len(fitted) < target_len:
            fitted.append(make_placeholder())
        while len(fitted) > target_len:
            # Surplus entries are dropped from the end, one warning per dropped entry.
            dropped = fitted.pop()
            _LOGGER.warning(f"{overflow_msg}{dropped.where}")
        return fitted

    def observe(self, state: Dict) -> ObsType:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary.
        :type state: Dict
        :return: Observation containing the status information about the host. The default observation is
            returned when the host is not present in the state.
        :rtype: ObsType
        """
        node_state = access_from_nested_dict(state, self.where)
        if node_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        # Child observations receive the full state because each one resolves its own ``where`` path.
        obs = {}
        obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
        obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
        obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
        obs["operating_status"] = node_state["operating_state"]
        obs["NICS"] = {
            i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
        }
        obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
        obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
        return obs

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        :return: Gymnasium space representing the observation space for host status.
        :rtype: spaces.Space
        """
        shape = {
            "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
            "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
            "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
            "operating_status": spaces.Discrete(5),
            "NICS": spaces.Dict(
                {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
            ),
            "num_file_creations": spaces.Discrete(4),
            "num_file_deletions": spaces.Discrete(4),
        }
        return spaces.Dict(shape)

    @classmethod
    def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation:
        """
        Create a host observation from a configuration schema.

        :param config: Configuration schema containing the necessary information for the host observation.
        :type config: ConfigSchema
        :param parent_where: Path in the simulation state dictionary under which the host's ``nodes`` entry is
            found. If None, the host is looked up at ['network', 'nodes', <hostname>]; otherwise at
            parent_where + ['nodes', <hostname>].
        :type parent_where: WhereType, optional
        :return: Constructed host observation instance.
        :rtype: HostObservation
        """
        if parent_where is None:
            where = ["network", "nodes", config.hostname]
        else:
            # Unpacking accepts any iterable path (list or tuple) and never aliases the caller's object.
            where = [*parent_where, "nodes", config.hostname]

        # Pass down shared/common config items
        for folder_config in config.folders:
            folder_config.include_num_access = config.include_num_access
            folder_config.num_files = config.num_files
        for nic_config in config.network_interfaces:
            nic_config.include_nmne = config.include_nmne

        services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services]
        applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications]
        folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders]
        nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces]

        return cls(
            where=where,
            services=services,
            applications=applications,
            folders=folders,
            network_interfaces=nics,
            num_services=config.num_services,
            num_applications=config.num_applications,
            num_folders=config.num_folders,
            num_files=config.num_files,
            num_nics=config.num_nics,
            include_nmne=config.include_nmne,
            include_num_access=config.include_num_access,
        )
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
    """Status information about a network interface within the simulation environment."""

    low_nmne_threshold: int = 0
    """The minimum number of malicious network events to be considered low."""
    med_nmne_threshold: int = 5
    """The minimum number of malicious network events to be considered medium."""
    high_nmne_threshold: int = 10
    """The minimum number of malicious network events to be considered high."""

    class ConfigSchema(AbstractObservation.ConfigSchema):
        """Configuration schema for NICObservation."""

        nic_num: int
        """Number of the network interface."""
        include_nmne: Optional[bool] = None
        """Whether to include number of malicious network events (NMNE) in the observation."""

    def __init__(self, where: WhereType, include_nmne: bool) -> None:
        """
        Initialize a network interface observation instance.

        :param where: Where in the simulation state dictionary to find the relevant information for this interface.
            A typical location for a network interface might be
            ['network', 'nodes', <node_hostname>, 'NICs', <nic_num>].
            If None, the interface is treated as absent and the default observation is always returned.
        :type where: WhereType
        :param include_nmne: Flag to determine whether to include NMNE information in the observation.
        :type include_nmne: bool
        """
        self.where = where
        self.include_nmne: bool = include_nmne

        self.default_observation: ObsType = {"nic_status": 0}
        if self.include_nmne:
            self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})

        # ``observe`` reports NMNEs per step by subtracting the count seen on the previous call from the
        # current count, so the previous counts have to exist before the first observation.
        self.nmne_inbound_last_step: int = 0
        self.nmne_outbound_last_step: int = 0

    def _categorise_mne_count(self, nmne_count: int) -> int:
        """
        Categorise a number of malicious network events (MNEs) into discrete bins.

        With the default thresholds the bins are:
        - 0: No MNEs detected (0 events).
        - 1: Low number of MNEs (1-5 events).
        - 2: Moderate number of MNEs (6-10 events).
        - 3: High number of MNEs (more than 10 events).

        :param nmne_count: Number of MNEs detected.
        :type nmne_count: int
        :return: Bin number (0, 1, 2 or 3) corresponding to the number of MNEs.
        :rtype: int
        """
        if nmne_count > self.high_nmne_threshold:
            return 3
        if nmne_count > self.med_nmne_threshold:
            return 2
        if nmne_count > self.low_nmne_threshold:
            return 1
        return 0

    def observe(self, state: Dict) -> ObsType:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary.
        :type state: Dict
        :return: Observation containing the status of the network interface and optionally NMNE information.
            The default observation is returned when the interface is not present in the state.
        :rtype: ObsType
        """
        # Placeholder interfaces are created with ``where=None`` (HostObservation pads its NIC list that way);
        # they have nothing to look up.
        if self.where is None:
            return self.default_observation

        nic_state = access_from_nested_dict(state, self.where)
        if nic_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        obs = {"nic_status": 1 if nic_state["enabled"] else 2}
        if self.include_nmne:
            obs.update({"NMNE": {}})
            direction_dict = nic_state["nmne"].get("direction", {})
            inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
            inbound_count = inbound_keywords.get("*", 0)
            outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
            outbound_count = outbound_keywords.get("*", 0)
            # Categorise the change since the previous observation, then remember the new counts.
            obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
            obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
            self.nmne_inbound_last_step = inbound_count
            self.nmne_outbound_last_step = outbound_count
        return obs

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        :return: Gymnasium space representing the observation space for network interface status and NMNE information.
        :rtype: spaces.Space
        """
        space = spaces.Dict({"nic_status": spaces.Discrete(3)})

        # Added after construction, exactly as before, so the key order of the resulting space is unchanged.
        if self.include_nmne:
            space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})

        return space

    @classmethod
    def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> NICObservation:
        """
        Create a network interface observation from a configuration schema.

        :param config: Configuration schema containing the necessary information for the network interface observation.
        :type config: ConfigSchema
        :param parent_where: Where in the simulation state dictionary to find the information about this NIC's
            parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
            None (the default) is treated as an empty path.
        :type parent_where: WhereType, optional
        :return: Constructed network interface observation instance.
        :rtype: NICObservation
        """
        # A None default replaces the shared mutable ``[]``; unpacking accepts any iterable path.
        parent_path = [] if parent_where is None else parent_where
        return cls(where=[*parent_path, "NICs", config.nic_num], include_nmne=config.include_nmne)


class PortObservation(AbstractObservation, identifier="PORT"):
    """Port observation, provides status information about a network port within the simulation environment."""

    class ConfigSchema(AbstractObservation.ConfigSchema):
        """Configuration schema for PortObservation."""

        port_id: int
        """Identifier of the port, used for querying simulation state dictionary."""

    def __init__(self, where: WhereType) -> None:
        """
        Initialize a port observation instance.

        :param where: Where in the simulation state dictionary to find the relevant information for this port.
            A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
            If None, the port is treated as absent and the default observation is always returned.
        :type where: WhereType
        """
        self.where = where
        self.default_observation: ObsType = {"operating_status": 0}

    def observe(self, state: Dict) -> ObsType:
        """
        Generate observation based on the current state of the simulation.

        :param state: Simulation state dictionary.
        :type state: Dict
        :return: Observation containing the operating status of the port: 1 when the port is enabled, 2 when it
            is not, and the default observation (0) when the port is not present in the state.
        :rtype: ObsType
        """
        # Same handling as NICObservation: a port without a location has nothing to look up.
        if self.where is None:
            return self.default_observation

        port_state = access_from_nested_dict(state, self.where)
        if port_state is NOT_PRESENT_IN_STATE:
            return self.default_observation
        return {"operating_status": 1 if port_state["enabled"] else 2}

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        :return: Gymnasium space representing the observation space for port status.
        :rtype: spaces.Space
        """
        return spaces.Dict({"operating_status": spaces.Discrete(3)})

    @classmethod
    def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> PortObservation:
        """
        Create a port observation from a configuration schema.

        :param config: Configuration schema containing the necessary information for the port observation.
        :type config: ConfigSchema
        :param parent_where: Where in the simulation state dictionary to find the information about this port's
            parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
            None (the default) is treated as an empty path.
        :type parent_where: WhereType, optional
        :return: Constructed port observation instance.
        :rtype: PortObservation
        """
        # A None default replaces the shared mutable ``[]``; unpacking accepts any iterable path.
        parent_path = [] if parent_where is None else parent_where
        return cls(where=[*parent_path, "NICs", config.port_id])
- :type where: WhereType - """ - self.where = where - self.default_observation = {"operating_status": 0, "health_status": 0} - - def observe(self, state: Dict) -> ObsType: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the operating status and health status of the service. - :rtype: Any - """ - service_state = access_from_nested_dict(state, self.where) - if service_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], - } - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for service status. - :rtype: spaces.Space - """ - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: - """ - Create a service observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the service observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this service's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed service observation instance. 
- :rtype: ServiceObservation - """ - return cls(where=parent_where + ["services", config.service_name]) - - -class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): - """Application observation, shows the status of an application within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for ApplicationObservation.""" - - application_name: str - """Name of the application, used for querying simulation state dictionary""" - - def __init__(self, where: WhereType) -> None: - """ - Initialise an application observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this application. - A typical location for an application might be - ['network', 'nodes', , 'applications', ]. - :type where: WhereType - """ - self.where = where - self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Obs containing the operating status, health status, and number of executions of the application. - :rtype: Any - """ - application_state = access_from_nested_dict(state, self.where) - if application_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": application_state["operating_state"], - "health_status": application_state["health_state_visible"], - "num_executions": application_state["num_executions"], - } - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for application status. 
- :rtype: spaces.Space - """ - return spaces.Dict( - { - "operating_status": spaces.Discrete(7), - "health_status": spaces.Discrete(5), - "num_executions": spaces.Discrete(4), - } - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation: - """ - Create an application observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the application observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this application's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed application observation instance. - :rtype: ApplicationObservation - """ - return cls(where=parent_where + ["applications", config.application_name]) - - -class FileObservation(AbstractObservation, identifier="FILE"): - """File observation, provides status information about a file within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for FileObservation.""" - - file_name: str - """Name of the file, used for querying simulation state dictionary.""" - include_num_access: Optional[bool] = None - """Whether to include the number of accesses to the file in the observation.""" - - def __init__(self, where: WhereType, include_num_access: bool) -> None: - """ - Initialize a file observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this file. - A typical location for a file might be - ['network', 'nodes', , 'file_system', 'folder', , 'files', ]. - :type where: WhereType - :param include_num_access: Whether to include the number of accesses to the file in the observation. 
- :type include_num_access: bool - """ - self.where: WhereType = where - self.include_num_access: bool = include_num_access - - self.default_observation: ObsType = {"health_status": 0} - if self.include_num_access: - self.default_observation["num_access"] = 0 - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the health status of the file and optionally the number of accesses. - :rtype: Any - """ - file_state = access_from_nested_dict(state, self.where) - if file_state is NOT_PRESENT_IN_STATE: - return self.default_observation - obs = {"health_status": file_state["visible_status"]} - if self.include_num_access: - obs["num_access"] = file_state["num_access"] - # raise NotImplementedError("TODO: need to fix num_access to use thresholds instead of raw value.") - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for file status. - :rtype: spaces.Space - """ - space = {"health_status": spaces.Discrete(6)} - if self.include_num_access: - space["num_access"] = spaces.Discrete(4) - return spaces.Dict(space) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FileObservation: - """ - Create a file observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the file observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this file's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed file observation instance. 
- :rtype: FileObservation - """ - return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access) - - -class FolderObservation(AbstractObservation, identifier="FOLDER"): - """Folder observation, provides status information about a folder within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for FolderObservation.""" - - folder_name: str - """Name of the folder, used for querying simulation state dictionary.""" - files: List[FileObservation.ConfigSchema] = [] - """List of file configurations within the folder.""" - num_files: Optional[int] = None - """Number of spaces for file observations in this folder.""" - include_num_access: Optional[bool] = None - """Whether files in this folder should include the number of accesses in their observation.""" - - def __init__( - self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool - ) -> None: - """ - Initialize a folder observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a folder might be ['network', 'nodes', , 'folders', ]. - :type where: WhereType - :param files: List of file observation instances within the folder. - :type files: Iterable[FileObservation] - :param num_files: Number of files expected in the folder. - :type num_files: int - :param include_num_access: Whether to include the number of accesses to files in the observation. - :type include_num_access: bool - """ - self.where: WhereType = where - - self.files: List[FileObservation] = files - while len(self.files) < num_files: - self.files.append(FileObservation(where=None, include_num_access=include_num_access)) - while len(self.files) > num_files: - truncated_file = self.files.pop() - msg = f"Too many files in folder observation. 
Truncating file {truncated_file}" - _LOGGER.warning(msg) - - self.default_observation = { - "health_status": 0, - "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the health status of the folder and status of files within the folder. - :rtype: Any - """ - folder_state = access_from_nested_dict(state, self.where) - if folder_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - health_status = folder_state["health_status"] - - obs = {} - - obs["health_status"] = health_status - obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} - - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for folder status. - :rtype: spaces.Space - """ - return spaces.Dict( - { - "health_status": spaces.Discrete(6), - "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), - } - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FolderObservation: - """ - Create a folder observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the folder observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed folder observation instance. 
- :rtype: FolderObservation - """ - where = parent_where + ["folders", config.folder_name] - - # pass down shared/common config items - for file_config in config.files: - file_config.include_num_access = config.include_num_access - - files = [FileObservation.from_config(config=f, parent_where=where) for f in config.files] - return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access) - - -class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): - """Status information about a network interface within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for NICObservation.""" - - nic_num: int - """Number of the network interface.""" - include_nmne: Optional[bool] = None - """Whether to include number of malicious network events (NMNE) in the observation.""" - - def __init__(self, where: WhereType, include_nmne: bool) -> None: - """ - Initialize a network interface observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this interface. - A typical location for a network interface might be - ['network', 'nodes', , 'NICs', ]. - :type where: WhereType - :param include_nmne: Flag to determine whether to include NMNE information in the observation. - :type include_nmne: bool - """ - self.where = where - self.include_nmne: bool = include_nmne - - self.default_observation: ObsType = {"nic_status": 0} - if self.include_nmne: - self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}}) - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the status of the network interface and optionally NMNE information. 
- :rtype: Any - """ - nic_state = access_from_nested_dict(state, self.where) - - if nic_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {"nic_status": 1 if nic_state["enabled"] else 2} - if self.include_nmne: - obs.update({"NMNE": {}}) - direction_dict = nic_state["nmne"].get("direction", {}) - inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) - inbound_count = inbound_keywords.get("*", 0) - outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) - outbound_count = outbound_keywords.get("*", 0) - obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) - obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) - self.nmne_inbound_last_step = inbound_count - self.nmne_outbound_last_step = outbound_count - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for network interface status and NMNE information. - :rtype: spaces.Space - """ - space = spaces.Dict({"nic_status": spaces.Discrete(3)}) - - if self.include_nmne: - space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) - - return space - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NICObservation: - """ - Create a network interface observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the network interface observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this NIC's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed network interface observation instance. 
- :rtype: NICObservation - """ - return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne) - - -class HostObservation(AbstractObservation, identifier="HOST"): - """Host observation, provides status information about a host within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for HostObservation.""" - - hostname: str - """Hostname of the host, used for querying simulation state dictionary.""" - services: List[ServiceObservation.ConfigSchema] = [] - """List of services to observe on the host.""" - applications: List[ApplicationObservation.ConfigSchema] = [] - """List of applications to observe on the host.""" - folders: List[FolderObservation.ConfigSchema] = [] - """List of folders to observe on the host.""" - network_interfaces: List[NICObservation.ConfigSchema] = [] - """List of network interfaces to observe on the host.""" - num_services: Optional[int] = None - """Number of spaces for service observations on this host.""" - num_applications: Optional[int] = None - """Number of spaces for application observations on this host.""" - num_folders: Optional[int] = None - """Number of spaces for folder observations on this host.""" - num_files: Optional[int] = None - """Number of spaces for file observations on this host.""" - num_nics: Optional[int] = None - """Number of spaces for network interface observations on this host.""" - include_nmne: Optional[bool] = None - """Whether network interface observations should include number of malicious network events.""" - include_num_access: Optional[bool] = None - """Whether to include the number of accesses to files observations on this host.""" - - def __init__( - self, - where: WhereType, - services: List[ServiceObservation], - applications: List[ApplicationObservation], - folders: List[FolderObservation], - network_interfaces: List[NICObservation], - num_services: int, - num_applications: int, - num_folders: int, - 
num_files: int, - num_nics: int, - include_nmne: bool, - include_num_access: bool, - ) -> None: - """ - Initialize a host observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this host. - A typical location for a host might be ['network', 'nodes', ]. - :type where: WhereType - :param services: List of service observations on the host. - :type services: List[ServiceObservation] - :param applications: List of application observations on the host. - :type applications: List[ApplicationObservation] - :param folders: List of folder observations on the host. - :type folders: List[FolderObservation] - :param network_interfaces: List of network interface observations on the host. - :type network_interfaces: List[NICObservation] - :param num_services: Number of services to observe. - :type num_services: int - :param num_applications: Number of applications to observe. - :type num_applications: int - :param num_folders: Number of folders to observe. - :type num_folders: int - :param num_files: Number of files. - :type num_files: int - :param num_nics: Number of network interfaces. - :type num_nics: int - :param include_nmne: Flag to include network metrics and errors. - :type include_nmne: bool - :param include_num_access: Flag to include the number of accesses to files. - :type include_num_access: bool - """ - self.where: WhereType = where - - # Ensure lists have lengths equal to specified counts by truncating or padding - self.services: List[ServiceObservation] = services - while len(self.services) < num_services: - self.services.append(ServiceObservation(where=None)) - while len(self.services) > num_services: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. 
Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - - self.applications: List[ApplicationObservation] = applications - while len(self.applications) < num_applications: - self.applications.append(ApplicationObservation(where=None)) - while len(self.applications) > num_applications: - truncated_application = self.applications.pop() - msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}" - _LOGGER.warning(msg) - - self.folders: List[FolderObservation] = folders - while len(self.folders) < num_folders: - self.folders.append( - FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access) - ) - while len(self.folders) > num_folders: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NICObservation] = network_interfaces - while len(self.network_interfaces) < num_nics: - self.network_interfaces.append(NICObservation(where=None, include_nmne=include_nmne)) - while len(self.network_interfaces) > num_nics: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}" - _LOGGER.warning(msg) - - self.default_observation: ObsType = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, - "num_file_creations": 0, - "num_file_deletions": 0, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. 
- :type state: Dict - :return: Observation containing the status information about the host. - :rtype: Any - """ - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NICS"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] - obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for host status. - :rtype: spaces.Space - """ - shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NICS": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - "num_file_creations": spaces.Discrete(4), - "num_file_deletions": spaces.Discrete(4), - } - return spaces.Dict(shape) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> HostObservation: - """ - Create a host observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the host observation. 
- :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this host. - A typical location might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed host observation instance. - :rtype: HostObservation - """ - if parent_where is None: - where = ["network", "nodes", config.hostname] - else: - where = parent_where + ["nodes", config.hostname] - - # Pass down shared/common config items - for folder_config in config.folders: - folder_config.include_num_access = config.include_num_access - folder_config.num_files = config.num_files - for nic_config in config.network_interfaces: - nic_config.include_nmne = config.include_nmne - - services = [ServiceObservation.from_config(config=c, parent_where=where) for c in config.services] - applications = [ApplicationObservation.from_config(config=c, parent_where=where) for c in config.applications] - folders = [FolderObservation.from_config(config=c, parent_where=where) for c in config.folders] - nics = [NICObservation.from_config(config=c, parent_where=where) for c in config.network_interfaces] - - return cls( - where=where, - services=services, - applications=applications, - folders=folders, - network_interfaces=nics, - num_services=config.num_services, - num_applications=config.num_applications, - num_folders=config.num_folders, - num_files=config.num_files, - num_nics=config.num_nics, - include_nmne=config.include_nmne, - include_num_access=config.include_num_access, - ) - - -class PortObservation(AbstractObservation, identifier="PORT"): - """Port observation, provides status information about a network port within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for PortObservation.""" - - port_id: int - """Identifier of the port, used for querying simulation state dictionary.""" - - def __init__(self, where: WhereType) -> None: - """ - Initialize a port 
observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this port. - A typical location for a port might be ['network', 'nodes', , 'NICs', ]. - :type where: WhereType - """ - self.where = where - self.default_observation: ObsType = {"operating_status": 0} - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the operating status of the port. - :rtype: Any - """ - port_state = access_from_nested_dict(state, self.where) - if port_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return {"operating_status": 1 if port_state["enabled"] else 2} - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for port status. - :rtype: spaces.Space - """ - return spaces.Dict({"operating_status": spaces.Discrete(3)}) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> PortObservation: - """ - Create a port observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the port observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this port's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed port observation instance. 
- :rtype: PortObservation - """ - return cls(where=parent_where + ["NICs", config.port_id]) - - -class ACLObservation(AbstractObservation, identifier="ACL"): - """ACL observation, provides information about access control lists within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for ACLObservation.""" - - ip_list: Optional[List[IPv4Address]] = None - """List of IP addresses.""" - wildcard_list: Optional[List[str]] = None - """List of wildcard strings.""" - port_list: Optional[List[int]] = None - """List of port numbers.""" - protocol_list: Optional[List[str]] = None - """List of protocol names.""" - num_rules: Optional[int] = None - """Number of ACL rules.""" - - def __init__( - self, - where: WhereType, - num_rules: int, - ip_list: List[IPv4Address], - wildcard_list: List[str], - port_list: List[int], - protocol_list: List[str], - ) -> None: - """ - Initialize an ACL observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this ACL. - :type where: WhereType - :param num_rules: Number of ACL rules. - :type num_rules: int - :param ip_list: List of IP addresses. - :type ip_list: List[IPv4Address] - :param wildcard_list: List of wildcard strings. - :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] - :param protocol_list: List of protocol names. 
- :type protocol_list: List[str] - """ - self.where = where - self.num_rules: int = num_rules - self.ip_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(ip_list)} - self.wildcard_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(wildcard_list)} - self.port_to_id: Dict[int, int] = {i + 2: p for i, p in enumerate(port_list)} - self.protocol_to_id: Dict[str, int] = {i + 2: p for i, p in enumerate(protocol_list)} - self.default_observation: Dict = { - i - + 1: { - "position": i, - "permission": 0, - "source_ip_id": 0, - "source_wildcard_id": 0, - "source_port_id": 0, - "dest_ip_id": 0, - "dest_wildcard_id": 0, - "dest_port_id": 0, - "protocol_id": 0, - } - for i in range(self.num_rules) - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing ACL rules. - :rtype: Any - """ - acl_state: Dict = access_from_nested_dict(state, self.where) - if acl_state is NOT_PRESENT_IN_STATE: - return self.default_observation - obs = {} - acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. 
- while i < self.num_rules + 1: - rule_state = acl_items[i] - if rule_state is None: - obs[i] = { - "position": i - 1, - "permission": 0, - "source_ip_id": 0, - "source_wildcard_id": 0, - "source_port_id": 0, - "dest_ip_id": 0, - "dest_wildcard_id": 0, - "dest_port_id": 0, - "protocol_id": 0, - } - else: - src_ip = rule_state["src_ip_address"] - src_node_id = self.ip_to_id.get(src_ip, 1) - dst_ip = rule_state["dst_ip_address"] - dst_node_ip = self.ip_to_id.get(dst_ip, 1) - src_wildcard = rule_state["source_wildcard_id"] - src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1) - dst_wildcard = rule_state["dest_wildcard_id"] - dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1) - src_port = rule_state["source_port_id"] - src_port_id = self.port_to_id.get(src_port, 1) - dst_port = rule_state["dest_port_id"] - dst_port_id = self.port_to_id.get(dst_port, 1) - protocol = rule_state["protocol"] - protocol_id = self.protocol_to_id.get(protocol, 1) - obs[i] = { - "position": i - 1, - "permission": rule_state["action"], - "source_ip_id": src_node_id, - "source_wildcard_id": src_wildcard_id, - "source_port_id": src_port_id, - "dest_ip_id": dst_node_ip, - "dest_wildcard_id": dst_wildcard_id, - "dest_port_id": dst_port_id, - "protocol_id": protocol_id, - } - i += 1 - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for ACL rules. 
- :rtype: spaces.Space - """ - return spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), - "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), - "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), - "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), - "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), - "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), - "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) - } - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: - """ - Create an ACL observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the ACL observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this ACL's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed ACL observation instance. 
- :rtype: ACLObservation - """ - return cls( - where=parent_where + ["acl", "acl"], - num_rules=config.num_rules, - ip_list=config.ip_list, - wildcard_list=config.wildcard_list, - port_list=config.port_list, - protocol_list=config.protocol_list, - ) - - -class RouterObservation(AbstractObservation, identifier="ROUTER"): - """Router observation, provides status information about a router within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for RouterObservation.""" - - hostname: str - """Hostname of the router, used for querying simulation state dictionary.""" - ports: Optional[List[PortObservation.ConfigSchema]] = None - """Configuration of port observations for this router.""" - num_ports: Optional[int] = None - """Number of port observations configured for this router.""" - acl: Optional[ACLObservation.ConfigSchema] = None - """Configuration of ACL observation on this router.""" - ip_list: Optional[List[str]] = None - """List of IP addresses for encoding ACLs.""" - wildcard_list: Optional[List[str]] = None - """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None - """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None - """List of protocols for encoding ACLs.""" - num_rules: Optional[int] = None - """Number of rules ACL rules to show.""" - - def __init__( - self, - where: WhereType, - ports: List[PortObservation], - num_ports: int, - acl: ACLObservation, - ) -> None: - """ - Initialize a router observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this router. - A typical location for a router might be ['network', 'nodes', ]. - :type where: WhereType - :param ports: List of port observations representing the ports of the router. - :type ports: List[PortObservation] - :param num_ports: Number of ports for the router. 
- :type num_ports: int - :param acl: ACL observation representing the access control list of the router. - :type acl: ACLObservation - """ - self.where: WhereType = where - self.ports: List[PortObservation] = ports - self.acl: ACLObservation = acl - self.num_ports: int = num_ports - - while len(self.ports) < num_ports: - self.ports.append(PortObservation(where=None)) - while len(self.ports) > num_ports: - self.ports.pop() - msg = "Too many ports in router observation. Truncating." - _LOGGER.warning(msg) - - self.default_observation = { - "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, - "ACL": self.acl.default_observation, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the status of ports and ACL configuration of the router. - :rtype: Any - """ - router_state = access_from_nested_dict(state, self.where) - if router_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} - obs["ACL"] = self.acl.observe(state) - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for router status. - :rtype: spaces.Space - """ - return spaces.Dict( - {"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space} - ) - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation: - """ - Create a router observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the router observation. 
- :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this router's - parent node. A typical location for a node might be ['network', 'nodes', ]. - :type parent_where: WhereType, optional - :return: Constructed router observation instance. - :rtype: RouterObservation - """ - where = parent_where + ["nodes", config.hostname] - - if config.acl is None: - config.acl = ACLObservation.ConfigSchema() - if config.acl.num_rules is None: - config.acl.num_rules = config.num_rules - if config.acl.ip_list is None: - config.acl.ip_list = config.ip_list - if config.acl.wildcard_list is None: - config.acl.wildcard_list = config.wildcard_list - if config.acl.port_list is None: - config.acl.port_list = config.port_list - if config.acl.protocol_list is None: - config.acl.protocol_list = config.protocol_list - - if config.ports is None: - config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)] - - ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] - acl = ACLObservation.from_config(config=config.acl, parent_where=where) - return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) - - -class FirewallObservation(AbstractObservation, identifier="FIREWALL"): - """Firewall observation, provides status information about a firewall within the simulation environment.""" - - class ConfigSchema(AbstractObservation.ConfigSchema): - """Configuration schema for FirewallObservation.""" - - hostname: str - """Hostname of the firewall node, used for querying simulation state dictionary.""" - ip_list: Optional[List[str]] = None - """List of IP addresses for encoding ACLs.""" - wildcard_list: Optional[List[str]] = None - """List of IP wildcards for encoding ACLs.""" - port_list: Optional[List[int]] = None - """List of ports for encoding ACLs.""" - protocol_list: Optional[List[str]] = None - """List of protocols for encoding ACLs.""" - 
num_rules: Optional[int] = None - """Number of rules ACL rules to show.""" - - def __init__( - self, - where: WhereType, - ip_list: List[str], - wildcard_list: List[str], - port_list: List[int], - protocol_list: List[str], - num_rules: int, - ) -> None: - """ - Initialize a firewall observation instance. - - :param where: Where in the simulation state dictionary to find the relevant information for this firewall. - A typical location for a firewall might be ['network', 'nodes', ]. - :type where: WhereType - :param ip_list: List of IP addresses. - :type ip_list: List[str] - :param wildcard_list: List of wildcard rules. - :type wildcard_list: List[str] - :param port_list: List of port numbers. - :type port_list: List[int] - :param protocol_list: List of protocol types. - :type protocol_list: List[str] - :param num_rules: Number of rules configured in the firewall. - :type num_rules: int - """ - self.where: WhereType = where - - self.ports: List[PortObservation] = [ - PortObservation(where=self.where + ["port", port_num]) for port_num in (1, 2, 3) - ] - # TODO: check what the port nums are for firewall. 
- - self.internal_inbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "inbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.internal_outbound_acl = ACLObservation( - where=self.where + ["acl", "internal", "outbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.dmz_inbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "inbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.dmz_outbound_acl = ACLObservation( - where=self.where + ["acl", "dmz", "outbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.external_inbound_acl = ACLObservation( - where=self.where + ["acl", "external", "inbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - self.external_outbound_acl = ACLObservation( - where=self.where + ["acl", "external", "outbound"], - num_rules=num_rules, - ip_list=ip_list, - wildcard_list=wildcard_list, - port_list=port_list, - protocol_list=protocol_list, - ) - - self.default_observation = { - "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.default_observation, - "OUTBOUND": self.internal_outbound_acl.default_observation, - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.default_observation, - "OUTBOUND": self.dmz_outbound_acl.default_observation, - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.default_observation, - "OUTBOUND": self.external_outbound_acl.default_observation, - }, - } - - def observe(self, state: Dict) -> Any: - """ - Generate observation based on the current state of 
the simulation. - - :param state: Simulation state dictionary. - :type state: Dict - :return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic. - :rtype: Any - """ - obs = { - "PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)}, - "INTERNAL": { - "INBOUND": self.internal_inbound_acl.observe(state), - "OUTBOUND": self.internal_outbound_acl.observe(state), - }, - "DMZ": { - "INBOUND": self.dmz_inbound_acl.observe(state), - "OUTBOUND": self.dmz_outbound_acl.observe(state), - }, - "EXTERNAL": { - "INBOUND": self.external_inbound_acl.observe(state), - "OUTBOUND": self.external_outbound_acl.observe(state), - }, - } - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Gymnasium space representing the observation space for firewall status. - :rtype: spaces.Space - """ - space = spaces.Dict( - { - "PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), - "INTERNAL": spaces.Dict( - { - "INBOUND": self.internal_inbound_acl.space, - "OUTBOUND": self.internal_outbound_acl.space, - } - ), - "DMZ": spaces.Dict( - { - "INBOUND": self.dmz_inbound_acl.space, - "OUTBOUND": self.dmz_outbound_acl.space, - } - ), - "EXTERNAL": spaces.Dict( - { - "INBOUND": self.external_inbound_acl.space, - "OUTBOUND": self.external_outbound_acl.space, - } - ), - } - ) - return space - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> FirewallObservation: - """ - Create a firewall observation from a configuration schema. - - :param config: Configuration schema containing the necessary information for the firewall observation. - :type config: ConfigSchema - :param parent_where: Where in the simulation state dictionary to find the information about this firewall's - parent node. A typical location for a node might be ['network', 'nodes', ]. 
- :type parent_where: WhereType, optional - :return: Constructed firewall observation instance. - :rtype: FirewallObservation - """ - where = parent_where + ["nodes", config.hostname] - return cls( - where=where, - ip_list=config.ip_list, - wildcard_list=config.wildcard_list, - port_list=config.port_list, - protocol_list=config.protocol_list, - num_rules=config.num_rules, - ) - class NodesObservation(AbstractObservation, identifier="NODES"): """Nodes observation, provides status information about nodes within the simulation environment.""" @@ -1266,14 +85,14 @@ class NodesObservation(AbstractObservation, identifier="NODES"): **{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)}, } - def observe(self, state: Dict) -> Any: + def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. :param state: Simulation state dictionary. :type state: Dict :return: Observation containing status information about nodes. - :rtype: Any + :rtype: ObsType """ obs = { **{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)}, @@ -1300,7 +119,7 @@ class NodesObservation(AbstractObservation, identifier="NODES"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation: """ Create a nodes observation from a configuration schema. 
diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py index dc41e8e5..08871072 100644 --- a/src/primaite/game/agent/observations/observations.py +++ b/src/primaite/game/agent/observations/observations.py @@ -1,24 +1,23 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod -from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Type +from typing import Any, Dict, Iterable, Type from gymnasium import spaces from pydantic import BaseModel, ConfigDict from primaite import getLogger -from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE _LOGGER = getLogger(__name__) -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +WhereType = Iterable[str | int] | None class AbstractObservation(ABC): """Abstract class for an observation space component.""" class ConfigSchema(ABC, BaseModel): + """Config schema for observations.""" + model_config = ConfigDict(extra="forbid") _registry: Dict[str, Type["AbstractObservation"]] = {} @@ -61,269 +60,271 @@ class AbstractObservation(ABC): @classmethod def from_config(cls, cfg: Dict) -> "AbstractObservation": """Create this observation space component form a serialised format.""" - ObservationType = cls._registry[cfg['type']] + ObservationType = cls._registry[cfg["type"]] return ObservationType.from_config(cfg=cfg) -# class LinkObservation(AbstractObservation): -# """Observation of a link in the network.""" +''' +class LinkObservation(AbstractObservation): + """Observation of a link in the network.""" -# default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} -# "Default observation is what should be returned when the link doesn't exist." + default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} + "Default observation is what should be returned when the link doesn't exist." 
-# def __init__(self, where: Optional[Tuple[str]] = None) -> None: -# """Initialise link observation. + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise link observation. -# :param where: Store information about where in the simulation state dictionary to find the relevant information. -# Optional. If None, this corresponds that the file does not exist and the observation will be populated with -# zeroes. + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. -# A typical location for a service looks like this: -# `['network','nodes',,'servics', ]` -# :type where: Optional[List[str]] -# """ -# super().__init__() -# self.where: Optional[Tuple[str]] = where + A typical location for a service looks like this: + `['network','nodes',,'servics', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where -# def observe(self, state: Dict) -> Dict: -# """Generate observation based on the current state of the simulation. + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. 
-# :param state: Simulation state dictionary -# :type state: Dict -# :return: Observation -# :rtype: Dict -# """ -# if self.where is None: -# return self.default_observation + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation -# link_state = access_from_nested_dict(state, self.where) -# if link_state is NOT_PRESENT_IN_STATE: -# return self.default_observation + link_state = access_from_nested_dict(state, self.where) + if link_state is NOT_PRESENT_IN_STATE: + return self.default_observation -# bandwidth = link_state["bandwidth"] -# load = link_state["current_load"] -# if load == 0: -# utilisation_category = 0 -# else: -# utilisation_fraction = load / bandwidth -# # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% -# utilisation_category = int(utilisation_fraction * 9) + 1 + bandwidth = link_state["bandwidth"] + load = link_state["current_load"] + if load == 0: + utilisation_category = 0 + else: + utilisation_fraction = load / bandwidth + # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% + utilisation_category = int(utilisation_fraction * 9) + 1 -# # TODO: once the links support separte load per protocol, this needs amendment to reflect that. -# return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} + # TODO: once the links support separte load per protocol, this needs amendment to reflect that. + return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} -# @property -# def space(self) -> spaces.Space: -# """Gymnasium space object describing the observation space shape. + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. 
-# :return: Gymnasium space -# :rtype: spaces.Space -# """ -# return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) -# @classmethod -# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": -# """Create link observation from a config. + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": + """Create link observation from a config. -# :param config: Dictionary containing the configuration for this link observation. -# :type config: Dict -# :param game: Reference to the PrimaiteGame object that spawned this observation. -# :type game: PrimaiteGame -# :return: Constructed link observation -# :rtype: LinkObservation -# """ -# return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) + :param config: Dictionary containing the configuration for this link observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Constructed link observation + :rtype: LinkObservation + """ + return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) -# class AclObservation(AbstractObservation): -# """Observation of an Access Control List (ACL) in the network.""" +class AclObservation(AbstractObservation): + """Observation of an Access Control List (ACL) in the network.""" -# # TODO: should where be optional, and we can use where=None to pad the observation space? -# # definitely the current approach does not support tracking files that aren't specified by name, for example -# # if a file is created at runtime, we have currently got no way of telling the observation space to track it. -# # this needs adding, but not for the MVP. 
-# def __init__( -# self, -# node_ip_to_id: Dict[str, int], -# ports: List[int], -# protocols: List[str], -# where: Optional[Tuple[str]] = None, -# num_rules: int = 10, -# ) -> None: -# """Initialise ACL observation. + # TODO: should where be optional, and we can use where=None to pad the observation space? + # definitely the current approach does not support tracking files that aren't specified by name, for example + # if a file is created at runtime, we have currently got no way of telling the observation space to track it. + # this needs adding, but not for the MVP. + def __init__( + self, + node_ip_to_id: Dict[str, int], + ports: List[int], + protocols: List[str], + where: Optional[Tuple[str]] = None, + num_rules: int = 10, + ) -> None: + """Initialise ACL observation. -# :param node_ip_to_id: Mapping between IP address and ID. -# :type node_ip_to_id: Dict[str, int] -# :param ports: List of ports which are part of the game that define the ordering when converting to an ID -# :type ports: List[int] -# :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID -# :type protocols: list[str] -# :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
A typical -# example may look like this: -# ['network','nodes',,'acl','acl'] -# :type where: Optional[Tuple[str]], optional -# :param num_rules: , defaults to 10 -# :type num_rules: int, optional -# """ -# super().__init__() -# self.where: Optional[Tuple[str]] = where -# self.num_rules: int = num_rules -# self.node_to_id: Dict[str, int] = node_ip_to_id -# "List of node IP addresses, order in this list determines how they are converted to an ID" -# self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} -# "List of ports which are part of the game that define the ordering when converting to an ID" -# self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} -# "List of protocols which are part of the game, defines ordering when converting to an ID" -# self.default_observation: Dict = { -# i -# + 1: { -# "position": i, -# "permission": 0, -# "source_node_id": 0, -# "source_port": 0, -# "dest_node_id": 0, -# "dest_port": 0, -# "protocol": 0, -# } -# for i in range(self.num_rules) -# } + :param node_ip_to_id: Mapping between IP address and ID. + :type node_ip_to_id: Dict[str, int] + :param ports: List of ports which are part of the game that define the ordering when converting to an ID + :type ports: List[int] + :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID + :type protocols: list[str] + :param where: Where in the simulation state dictionary to find the relevant information for this ACL. 
A typical + example may look like this: + ['network','nodes',,'acl','acl'] + :type where: Optional[Tuple[str]], optional + :param num_rules: , defaults to 10 + :type num_rules: int, optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + self.num_rules: int = num_rules + self.node_to_id: Dict[str, int] = node_ip_to_id + "List of node IP addresses, order in this list determines how they are converted to an ID" + self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} + "List of ports which are part of the game that define the ordering when converting to an ID" + self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} + "List of protocols which are part of the game, defines ordering when converting to an ID" + self.default_observation: Dict = { + i + + 1: { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + for i in range(self.num_rules) + } -# def observe(self, state: Dict) -> Dict: -# """Generate observation based on the current state of the simulation. + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. 
-# :param state: Simulation state dictionary -# :type state: Dict -# :return: Observation -# :rtype: Dict -# """ -# if self.where is None: -# return self.default_observation -# acl_state: Dict = access_from_nested_dict(state, self.where) -# if acl_state is NOT_PRESENT_IN_STATE: -# return self.default_observation + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + acl_state: Dict = access_from_nested_dict(state, self.where) + if acl_state is NOT_PRESENT_IN_STATE: + return self.default_observation -# # TODO: what if the ACL has more rules than num of max rules for obs space -# obs = {} -# acl_items = dict(acl_state.items()) -# i = 1 # don't show rule 0 for compatibility reasons. -# while i < self.num_rules + 1: -# rule_state = acl_items[i] -# if rule_state is None: -# obs[i] = { -# "position": i - 1, -# "permission": 0, -# "source_node_id": 0, -# "source_port": 0, -# "dest_node_id": 0, -# "dest_port": 0, -# "protocol": 0, -# } -# else: -# src_ip = rule_state["src_ip_address"] -# src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] -# dst_ip = rule_state["dst_ip_address"] -# dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] -# src_port = rule_state["src_port"] -# src_port_id = 1 if src_port is None else self.port_to_id[src_port] -# dst_port = rule_state["dst_port"] -# dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] -# protocol = rule_state["protocol"] -# protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] -# obs[i] = { -# "position": i - 1, -# "permission": rule_state["action"], -# "source_node_id": src_node_id, -# "source_port": src_port_id, -# "dest_node_id": dst_node_ip, -# "dest_port": dst_port_id, -# "protocol": protocol_id, -# } -# i += 1 -# return obs + # TODO: what if the ACL has more rules than num of max rules for obs space + obs = {} + acl_items = 
dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] + if rule_state is None: + obs[i] = { + "position": i - 1, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + else: + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, + "permission": rule_state["action"], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, + } + i += 1 + return obs -# @property -# def space(self) -> spaces.Space: -# """Gymnasium space object describing the observation space shape. + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. 
-# :return: Gymnasium space -# :rtype: spaces.Space -# """ -# return spaces.Dict( -# { -# i -# + 1: spaces.Dict( -# { -# "position": spaces.Discrete(self.num_rules), -# "permission": spaces.Discrete(3), -# # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) -# "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), -# "source_port": spaces.Discrete(len(self.port_to_id) + 2), -# "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), -# "dest_port": spaces.Discrete(len(self.port_to_id) + 2), -# "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), -# } -# ) -# for i in range(self.num_rules) -# } -# ) + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "source_port": spaces.Discrete(len(self.port_to_id) + 2), + "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), + } + ) + for i in range(self.num_rules) + } + ) -# @classmethod -# def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": -# """Generate ACL observation from a config. + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": + """Generate ACL observation from a config. -# :param config: Dictionary containing the configuration for this ACL observation. -# :type config: Dict -# :param game: Reference to the PrimaiteGame object that spawned this observation. 
-# :type game: PrimaiteGame -# :return: Observation object -# :rtype: AclObservation -# """ -# max_acl_rules = config["options"]["max_acl_rules"] -# node_ip_to_idx = {} -# for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): -# node_ref = ip_map_config["node_hostname"] -# nic_num = ip_map_config["nic_num"] -# node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] -# nic_obj = node_obj.network_interface[nic_num] -# node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 + :param config: Dictionary containing the configuration for this ACL observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Observation object + :rtype: AclObservation + """ + max_acl_rules = config["options"]["max_acl_rules"] + node_ip_to_idx = {} + for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): + node_ref = ip_map_config["node_hostname"] + nic_num = ip_map_config["nic_num"] + node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] + nic_obj = node_obj.network_interface[nic_num] + node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 -# router_hostname = config["router_hostname"] -# return cls( -# node_ip_to_id=node_ip_to_idx, -# ports=game.options.ports, -# protocols=game.options.protocols, -# where=["network", "nodes", router_hostname, "acl", "acl"], -# num_rules=max_acl_rules, -# ) + router_hostname = config["router_hostname"] + return cls( + node_ip_to_id=node_ip_to_idx, + ports=game.options.ports, + protocols=game.options.protocols, + where=["network", "nodes", router_hostname, "acl", "acl"], + num_rules=max_acl_rules, + ) -# class NullObservation(AbstractObservation): -# """Null observation, returns a single 0 value for the observation space.""" +class NullObservation(AbstractObservation): + """Null observation, returns a single 0 value for the observation space.""" -# def __init__(self, where: Optional[List[str]] = None): -# """Initialise 
null observation.""" -# self.default_observation: Dict = {} + def __init__(self, where: Optional[List[str]] = None): + """Initialise null observation.""" + self.default_observation: Dict = {} -# def observe(self, state: Dict) -> Dict: -# """Generate observation based on the current state of the simulation.""" -# return 0 + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation.""" + return 0 -# @property -# def space(self) -> spaces.Space: -# """Gymnasium space object describing the observation space shape.""" -# return spaces.Discrete(1) + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Discrete(1) -# @classmethod -# def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": -# """ -# Create null observation from a config. + @classmethod + def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": + """ + Create null observation from a config. -# The parameters are ignored, they are here to match the signature of the other observation classes. -# """ -# return cls() + The parameters are ignored, they are here to match the signature of the other observation classes. 
+ """ + return cls() -# class ICSObservation(NullObservation): -# """ICS observation placeholder, currently not implemented so always returns a single 0.""" +class ICSObservation(NullObservation): + """ICS observation placeholder, currently not implemented so always returns a single 0.""" -# pass + pass +''' diff --git a/src/primaite/game/agent/observations/router_observation.py b/src/primaite/game/agent/observations/router_observation.py new file mode 100644 index 00000000..b8dee2c2 --- /dev/null +++ b/src/primaite/game/agent/observations/router_observation.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +from typing import Dict, List, Optional + +from gymnasium import spaces +from gymnasium.core import ObsType + +from primaite import getLogger +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.nic_observations import PortObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + + +class RouterObservation(AbstractObservation, identifier="ROUTER"): + """Router observation, provides status information about a router within the simulation environment.""" + + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for RouterObservation.""" + + hostname: str + """Hostname of the router, used for querying simulation state dictionary.""" + ports: Optional[List[PortObservation.ConfigSchema]] = None + """Configuration of port observations for this router.""" + num_ports: Optional[int] = None + """Number of port observations configured for this router.""" + acl: Optional[ACLObservation.ConfigSchema] = None + """Configuration of ACL observation on this router.""" + ip_list: Optional[List[str]] = None + """List of IP addresses for encoding ACLs.""" + wildcard_list: Optional[List[str]] = None + """List of IP wildcards 
for encoding ACLs.""" + port_list: Optional[List[int]] = None + """List of ports for encoding ACLs.""" + protocol_list: Optional[List[str]] = None + """List of protocols for encoding ACLs.""" + num_rules: Optional[int] = None + """Number of ACL rules to show.""" + + def __init__( + self, + where: WhereType, + ports: List[PortObservation], + num_ports: int, + acl: ACLObservation, + ) -> None: + """ + Initialize a router observation instance. + + :param where: Where in the simulation state dictionary to find the relevant information for this router. + A typical location for a router might be ['network', 'nodes', <node_hostname>]. + :type where: WhereType + :param ports: List of port observations representing the ports of the router. + :type ports: List[PortObservation] + :param num_ports: Number of ports for the router. + :type num_ports: int + :param acl: ACL observation representing the access control list of the router. + :type acl: ACLObservation + """ + self.where: WhereType = where + self.ports: List[PortObservation] = ports + self.acl: ACLObservation = acl + self.num_ports: int = num_ports + + while len(self.ports) < num_ports: + self.ports.append(PortObservation(where=None)) + while len(self.ports) > num_ports: + self.ports.pop() + msg = "Too many ports in router observation. Truncating." + _LOGGER.warning(msg) + + self.default_observation = { + "PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)}, + "ACL": self.acl.default_observation, + } + + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. + :type state: Dict + :return: Observation containing the status of ports and ACL configuration of the router. 
+ :rtype: ObsType + """ + router_state = access_from_nested_dict(state, self.where) + if router_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)} + obs["ACL"] = self.acl.observe(state) + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for router status. + :rtype: spaces.Space + """ + return spaces.Dict( + {"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}), "ACL": self.acl.space} + ) + + @classmethod + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> RouterObservation: + """ + Create a router observation from a configuration schema. + + :param config: Configuration schema containing the necessary information for the router observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this router's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed router observation instance. 
+ :rtype: RouterObservation + """ + where = parent_where + ["nodes", config.hostname] + + if config.acl is None: + config.acl = ACLObservation.ConfigSchema() + if config.acl.num_rules is None: + config.acl.num_rules = config.num_rules + if config.acl.ip_list is None: + config.acl.ip_list = config.ip_list + if config.acl.wildcard_list is None: + config.acl.wildcard_list = config.wildcard_list + if config.acl.port_list is None: + config.acl.port_list = config.port_list + if config.acl.protocol_list is None: + config.acl.protocol_list = config.protocol_list + + if config.ports is None: + config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)] + + ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] + acl = ACLObservation.from_config(config=config.acl, parent_where=where) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py index 6caf791c..eb94651d 100644 --- a/src/primaite/game/agent/observations/software_observation.py +++ b/src/primaite/game/agent/observations/software_observation.py @@ -1,45 +1,43 @@ -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from __future__ import annotations + +from typing import Dict from gymnasium import spaces +from gymnasium.core import ObsType -from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame +class ServiceObservation(AbstractObservation, identifier="SERVICE"): + """Service observation, shows status of a service in the simulation environment.""" -class ServiceObservation(AbstractObservation): - """Observation of a service in 
the network.""" + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ServiceObservation.""" - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} - "Default observation is what should be returned when the service doesn't exist." + service_name: str + """Name of the service, used for querying simulation state dictionary""" - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise service observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'services', ]` - :type where: Optional[List[str]] + def __init__(self, where: WhereType) -> None: """ - super().__init__() - self.where: Optional[Tuple[str]] = where + Initialize a service observation instance. - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. + :param where: Where in the simulation state dictionary to find the relevant information for this service. + A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>]. + :type where: WhereType + """ + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0} - :param state: Simulation state dictionary + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. :type state: Dict - :return: Observation - :rtype: Dict + :return: Observation containing the operating status and health status of the service. 
+ :rtype: ObsType """ - if self.where is None: - return self.default_observation - service_state = access_from_nested_dict(state, self.where) if service_state is NOT_PRESENT_IN_STATE: return self.default_observation @@ -50,114 +48,96 @@ class ServiceObservation(AbstractObservation): @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for service status. + :rtype: spaces.Space + """ return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ServiceObservation": - """Create service observation from a config. + def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ServiceObservation: + """ + Create a service observation from a configuration schema. - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation + :param config: Configuration schema containing the necessary information for the service observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this service's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed service observation instance. 
:rtype: ServiceObservation """ - return cls(where=parent_where + ["services", config["service_name"]]) + return cls(where=parent_where + ["services", config.service_name]) -class ApplicationObservation(AbstractObservation): - """Observation of an application in the network.""" +class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): + """Application observation, shows the status of an application within the simulation environment.""" - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0} - "Default observation is what should be returned when the application doesn't exist." + class ConfigSchema(AbstractObservation.ConfigSchema): + """Configuration schema for ApplicationObservation.""" - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise application observation. + application_name: str + """Name of the application, used for querying simulation state dictionary""" - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'applications', ]` - :type where: Optional[List[str]] + def __init__(self, where: WhereType) -> None: """ - super().__init__() - self.where: Optional[Tuple[str]] = where + Initialise an application observation instance. - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. + :param where: Where in the simulation state dictionary to find the relevant information for this application. + A typical location for an application might be + ['network', 'nodes', , 'applications', ]. 
+ :type where: WhereType + """ + self.where = where + self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0} - :param state: Simulation state dictionary + def observe(self, state: Dict) -> ObsType: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary. :type state: Dict - :return: Observation - :rtype: Dict + :return: Obs containing the operating status, health status, and number of executions of the application. + :rtype: ObsType """ - if self.where is None: - return self.default_observation - - app_state = access_from_nested_dict(state, self.where) - if app_state is NOT_PRESENT_IN_STATE: + application_state = access_from_nested_dict(state, self.where) + if application_state is NOT_PRESENT_IN_STATE: return self.default_observation return { - "operating_status": app_state["operating_state"], - "health_status": app_state["health_state_visible"], - "num_executions": self._categorise_num_executions(app_state["num_executions"]), + "operating_status": application_state["operating_state"], + "health_status": application_state["health_state_visible"], + "num_executions": application_state["num_executions"], } @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" + """ + Gymnasium space object describing the observation space shape. + + :return: Gymnasium space representing the observation space for application status. + :rtype: spaces.Space + """ return spaces.Dict( { "operating_status": spaces.Discrete(7), - "health_status": spaces.Discrete(6), + "health_status": spaces.Discrete(5), "num_executions": spaces.Discrete(4), } ) @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ApplicationObservation": - """Create application observation from a config. 
+ def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ApplicationObservation: + """ + Create an application observation from a configuration schema. - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation + :param config: Configuration schema containing the necessary information for the application observation. + :type config: ConfigSchema + :param parent_where: Where in the simulation state dictionary to find the information about this application's + parent node. A typical location for a node might be ['network', 'nodes', ]. + :type parent_where: WhereType, optional + :return: Constructed application observation instance. :rtype: ApplicationObservation """ - return cls(where=parent_where + ["services", config["application_name"]]) - - @classmethod - def _categorise_num_executions(cls, num_executions: int) -> int: - """ - Categorise the number of executions of an application. - - Helps classify the number of application executions into different categories. - - Current categories: - - 0: Application is never executed - - 1: Application is executed a low number of times (1-5) - - 2: Application is executed often (6-10) - - 3: Application is executed a high number of times (more than 10) - - :param: num_executions: Number of times the application is executed - """ - if num_executions > 10: - return 3 - elif num_executions > 5: - return 2 - elif num_executions > 0: - return 1 - return 0 + return cls(where=parent_where + ["applications", config.application_name])