From f88b4c0f97716ff03344ae22d732252733749c58 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 28 Mar 2024 17:40:27 +0000 Subject: [PATCH] #2417 more observations --- .../agent/observations/node_observations.py | 539 +++++++++++------- 1 file changed, 322 insertions(+), 217 deletions(-) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 42bdb749..5d46b743 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,4 +1,6 @@ +# TODO: make sure when config options are being passed down from higher-level observations to lower-level, but the lower-level also defines that option, don't overwrite. from __future__ import annotations +from ipaddress import IPv4Address from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union from gymnasium import spaces @@ -163,7 +165,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): } ) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FolderObservation: where = parent_where + ["folders", config.folder_name] #pass down shared/common config items @@ -220,7 +222,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> NICObservation: return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne) @@ -333,7 +335,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): return spaces.Dict(shape) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation: + def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> HostObservation: if parent_where is None: where = ["network", "nodes", config.hostname] else: @@ -369,78 +371,282 @@ class HostObservation(AbstractObservation, identifier="HOST"): class PortObservation(AbstractObservation, identifier="PORT"): class ConfigSchema(AbstractObservation.ConfigSchema): - pass + port_id : int def __init__(self, where: WhereType)->None: - pass + self.where = where + self.default_observation: ObsType = {"operating_status" : 0} def observe(self, state: Dict) -> Any: - pass + port_state = access_from_nested_dict(state, self.where) + if port_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"operating_status": 1 if port_state["enabled"] else 2 } @property def space(self) -> spaces.Space: - pass + return spaces.Dict({"operating_status": spaces.Discrete(3)}) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> PortObservation: + return cls(where = parent_where + ["NICs", config.port_id]) class ACLObservation(AbstractObservation, identifier="ACL"): class ConfigSchema(AbstractObservation.ConfigSchema): - pass + ip_list: List[IPv4Address] + port_list: List[int] + protocol_list: List[str] + num_rules: int - def __init__(self, where: WhereType)->None: - pass + def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None: + self.where = where + self.num_rules: int = num_rules + self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)} + self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)} + self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)} + self.default_observation: Dict = { + i + + 1: { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + for i in range(self.num_rules) + } def observe(self, state: Dict) -> Any: - pass + acl_state: Dict = access_from_nested_dict(state, self.where) + if acl_state is NOT_PRESENT_IN_STATE: + return self.default_observation + obs = {} + acl_items = dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] + if rule_state is None: + obs[i] = { + "position": i - 1, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + else: + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, + "permission": rule_state["action"], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, + } + i += 1 + return obs @property def space(self) -> spaces.Space: - pass + raise NotImplementedError("TODO: need to add wildcard id.") + return spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "source_port": spaces.Discrete(len(self.port_to_id) + 2), + "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), + } + ) + for i in range(self.num_rules) + } + ) + @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ACLObservation: + return cls( + where = parent_where+["acl", "acl"], + num_rules = config.num_rules, + ip_list = config.ip_list, + ports = config.port_list, + protocols = config.protocol_list + ) class RouterObservation(AbstractObservation, identifier="ROUTER"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str ports: List[PortObservation.ConfigSchema] + num_ports: int + acl: ACLObservation.ConfigSchema + ip_list: List[str] + port_list: List[int] + protocol_list: List[str] + num_rules: int + def __init__(self, + where: WhereType, + ports:List[PortObservation], + num_ports: int, + acl: ACLObservation, + )->None: + self.where: WhereType = where + self.ports: List[PortObservation] = ports + self.acl: ACLObservation = acl + self.num_ports:int = num_ports - def __init__(self, where: WhereType)->None: - pass + while len(self.ports) < num_ports: + self.ports.append(PortObservation(where=None)) + while len(self.ports) > num_ports: + self.ports.pop() + msg = f"Too many ports in router observation. Truncating." + _LOGGER.warning(msg) + + self.default_observation = { + "PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)}, + "ACL": self.acl.default_observation + } def observe(self, state: Dict) -> Any: - pass + router_state = access_from_nested_dict(state, self.where) + if router_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["PORTS"] = {i+1:p.observe(state) for i,p in enumerate(self.ports)} + obs["ACL"] = self.acl.observe(state) + return obs @property def space(self) -> spaces.Space: - pass + return spaces.Dict({ + "PORTS": {i+1:p.space for i,p in self.ports}, + "ACL": self.acl.space + }) @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation: + where = parent_where + ["nodes", config.hostname] + + if config.acl.num_rules is None: + config.acl.num_rules = config.num_rules + if config.acl.ip_list is None: + config.acl.ip_list = config.ip_list + if config.acl.port_list is None: + config.acl.port_list = config.port_list + if config.acl.protocol_list is None: + config.acl.protocol_list = config.protocol_list + + ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] + acl = ACLObservation.from_config(config=config.acl, parent_where=where) + return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) class FirewallObservation(AbstractObservation, identifier="FIREWALL"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str - ports: List[PortObservation.ConfigSchema] = [] + ip_list: List[str] + port_list: List[int] + protocol_list: List[str] + num_rules: int - def __init__(self, where: WhereType)->None: - pass + + def __init__(self, + where: WhereType, + ip_list: List[str], + port_list: List[int], + protocol_list: List[str], + num_rules: int, + )->None: + self.where: WhereType = where + + self.ports: List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ] + #TODO: check what the port nums are for firewall. + + self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + + + self.default_observation = { + "PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.default_observation, + "OUTBOUND": self.internal_outbound_acl.default_observation, + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.default_observation, + "OUTBOUND": self.dmz_outbound_acl.default_observation, + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.default_observation, + "OUTBOUND": self.external_outbound_acl.default_observation, + }, + } def observe(self, state: Dict) -> Any: - pass + obs = { + "PORTS": {i+1:p.observe(state) for i,p in enumerate(self.ports)}, + "INTERNAL": { + "INBOUND": self.internal_inbound_acl.observe(state), + "OUTBOUND": self.internal_outbound_acl.observe(state), + }, + "DMZ": { + "INBOUND": self.dmz_inbound_acl.observe(state), + "OUTBOUND": self.dmz_outbound_acl.observe(state), + }, + "EXTERNAL": { + "INBOUND": self.external_inbound_acl.observe(state), + "OUTBOUND": self.external_outbound_acl.observe(state), + }, + } + return obs @property def space(self) -> spaces.Space: - pass + space =spaces.Dict({ + "PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}), + "INTERNAL": spaces.Dict({ + "INBOUND": self.internal_inbound_acl.space, + "OUTBOUND": self.internal_outbound_acl.space, + }), + "DMZ": spaces.Dict({ + "INBOUND": self.dmz_inbound_acl.space, + "OUTBOUND": self.dmz_outbound_acl.space, + }), + "EXTERNAL": spaces.Dict({ + "INBOUND": self.external_inbound_acl.space, + "OUTBOUND": self.external_outbound_acl.space, + }), + }) + return space @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation: + where = parent_where+["nodes", config.hostname] + return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules) class NodesObservation(AbstractObservation, identifier="NODES"): class ConfigSchema(AbstractObservation.ConfigSchema): @@ -448,205 +654,104 @@ class NodesObservation(AbstractObservation, identifier="NODES"): hosts: List[HostObservation.ConfigSchema] = [] routers: List[RouterObservation.ConfigSchema] = [] firewalls: List[FirewallObservation.ConfigSchema] = [] - num_services: int = 1 + + num_services: int + num_applications: int + num_folders: int + num_files: int + num_nics: int + include_nmne: bool + include_num_access: bool + + ip_list: List[str] + port_list: List[int] + protocol_list: List[str] + num_rules: int - def __init__(self, where: WhereType)->None: - pass + def __init__(self, where: WhereType, hosts:List[HostObservation], routers:List[RouterObservation], firewalls:List[FirewallObservation])->None: + self.where :WhereType = where + + self.hosts: List[HostObservation] = hosts + self.routers: List[RouterObservation] = routers + self.firewalls: List[FirewallObservation] = firewalls + + self.default_observation = { + **{f"HOST{i}":host.default_observation for i,host in enumerate(self.hosts)}, + **{f"ROUTER{i}":router.default_observation for i,router in enumerate(self.routers)}, + **{f"FIREWALL{i}":firewall.default_observation for i,firewall in enumerate(self.firewalls)}, + } def observe(self, state: Dict) -> Any: - pass - - @property - def space(self) -> spaces.Space: - pass - - @classmethod - def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: - pass - -############################ OLD - -class NodeObservation(AbstractObservation, identifier= "OLD"): - """Observation of a node in the network. Includes services, folders and NICs.""" - - def __init__( - self, - where: Optional[Tuple[str]] = None, - services: List[ServiceObservation] = [], - folders: List[FolderObservation] = [], - network_interfaces: List[NicObservation] = [], - logon_status: bool = False, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> None: - """ - Configurable observation for a node in the simulation. - - :param where: Where in the simulation state dictionary for find relevant information for this observation. - A typical location for a node looks like this: - ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] - :type where: List[str], optional - :param services: Mapping between position in observation space and service name, defaults to {} - :type services: Dict[int,str], optional - :param max_services: Max number of services that can be presented in observation space for this node - , defaults to 2 - :type max_services: int, optional - :param folders: Mapping between position in observation space and folder name, defaults to {} - :type folders: Dict[int,str], optional - :param max_folders: Max number of folders in this node's obs space, defaults to 2 - :type max_folders: int, optional - :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} - :type network_interfaces: Dict[int,str], optional - :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 - :type max_nics: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.services: List[ServiceObservation] = services - while len(self.services) < num_services_per_node: - # add empty service observation without `where` parameter so it always returns default (blank) observation - self.services.append(ServiceObservation()) - while len(self.services) > num_services_per_node: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - # truncate service list - - self.folders: List[FolderObservation] = folders - # add empty folder observation without `where` parameter that will always return default (blank) observations - while len(self.folders) < num_folders_per_node: - self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) - while len(self.folders) > num_folders_per_node: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NicObservation] = network_interfaces - while len(self.network_interfaces) < num_nics_per_node: - self.network_interfaces.append(NicObservation()) - while len(self.network_interfaces) > num_nics_per_node: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" - _LOGGER.warning(msg) - - self.logon_status: bool = logon_status - - self.default_observation: Dict = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, + obs = { + **{f"HOST{i}":host.observe(state) for i,host in enumerate(self.hosts)}, + **{f"ROUTER{i}":router.observe(state) for i,router in enumerate(self.routers)}, + **{f"FIREWALL{i}":firewall.observe(state) for i,firewall in enumerate(self.firewalls)}, } - if self.logon_status: - self.default_observation["logon_status"] = 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NICS"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - - if self.logon_status: - obs["logon_status"] = 0 - return obs @property def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - space_shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NICS": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - } - if self.logon_status: - space_shape["logon_status"] = spaces.Discrete(3) - - return spaces.Dict(space_shape) + space = spaces.Dict({ + **{f"HOST{i}":host.space for i,host in enumerate(self.hosts)}, + **{f"ROUTER{i}":router.space for i,router in enumerate(self.routers)}, + **{f"FIREWALL{i}":firewall.space for i,firewall in enumerate(self.firewalls)}, + }) + return space @classmethod - def from_config( - cls, - config: Dict, - game: "PrimaiteGame", - parent_where: Optional[List[str]] = None, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> "NodeObservation": - """Create node observation from a config. Also creates child service, folder and NIC observations. - - :param config: Dictionary containing the configuration for this node observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this node's parent - network. A typical location for it would be: ['network',] - :type parent_where: Optional[List[str]] - :param num_services_per_node: How many spaces for services are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_services_per_node: int, optional - :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_folders_per_node: int, optional - :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed node observation - :rtype: NodeObservation - """ - node_hostname = config["node_hostname"] + def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation: if parent_where is None: - where = ["network", "nodes", node_hostname] + where = ["network", "nodes"] else: - where = parent_where + ["nodes", node_hostname] + where = parent_where + ["nodes"] - svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] - folder_configs = config.get("folders", {}) - folders = [ - FolderObservation.from_config( - config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder - ) - for c in folder_configs - ] - # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. - nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] - network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] - logon_status = config.get("logon_status", False) - return cls( - where=where, - services=services, - folders=folders, - network_interfaces=network_interfaces, - logon_status=logon_status, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) + for host_config in config.hosts: + if host_config.num_services is None: + host_config.num_services = config.num_services + if host_config.num_applications is None: + host_config.num_application = config.num_applications + if host_config.num_folders is None: + host_config.num_folder = config.num_folders + if host_config.num_files is None: + host_config.num_file = config.num_files + if host_config.num_nics is None: + host_config.num_nic = config.num_nics + if host_config.include_nmne is None: + host_config.include_nmne = config.include_nmne + if host_config.include_num_access is None: + host_config.include_num_access = config.include_num_access + + for router_config in config.routers: + if router_config.num_ports is None: + router_config.num_ports = config.num_ports + if router_config.ip_list is None: + router_config.ip_list = config.ip_list + + if router_config.port_list is None: + router_config.port_list = config.port_list + + if router_config.protocol_list is None: + router_config.protocol_list = config.protocol_list + + if router_config.num_rules is None: + router_config.num_rules = config.num_rules + + for firewall_config in config.firewalls: + if firewall_config.ip_list is None: + firewall_config.ip_list = config.ip_list + + if firewall_config.port_list is None: + firewall_config.port_list = config.port_list + + if firewall_config.protocol_list is None: + firewall_config.protocol_list = config.protocol_list + + if firewall_config.num_rules is None: + firewall_config.num_rules = config.num_rules + + hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts] + routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] + firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls] + + cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)