diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 5d46b743..b51ea1f2 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"): class FileObservation(AbstractObservation, identifier="FILE"): class ConfigSchema(AbstractObservation.ConfigSchema): file_name: str - include_num_access : bool = False + include_num_access: Optional[bool] = None def __init__(self, where: WhereType, include_num_access: bool)->None: self.where: WhereType = where @@ -118,8 +118,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): class ConfigSchema(AbstractObservation.ConfigSchema): folder_name: str files: List[FileObservation.ConfigSchema] = [] - num_files : int = 0 - include_num_access : bool = False + num_files : Optional[int] = None + include_num_access : Optional[bool] = None def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None: self.where: WhereType = where @@ -179,7 +179,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"): class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"): class ConfigSchema(AbstractObservation.ConfigSchema): nic_num: int - include_nmne: bool = False + include_nmne: Optional[bool] = None def __init__(self, where: WhereType, include_nmne: bool)->None: @@ -233,13 +233,13 @@ class HostObservation(AbstractObservation, identifier="HOST"): applications: List[ApplicationObservation.ConfigSchema] = [] folders: List[FolderObservation.ConfigSchema] = [] network_interfaces: List[NICObservation.ConfigSchema] = [] - num_services: int - num_applications: int - num_folders: int - num_files: int - num_nics: int - include_nmne: bool - include_num_access: bool + num_services: Optional[int] = None + num_applications: Optional[int] = None + num_folders: Optional[int] = None + num_files: Optional[int] = None + num_nics: Optional[int] = None + include_nmne: Optional[bool] = None + include_num_access: Optional[bool] = None def __init__(self, where: WhereType, @@ -296,6 +296,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): self.default_observation: ObsType = { "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)}, "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, "operating_status": 0, @@ -311,6 +312,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): obs = {} obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)} obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} obs["operating_status"] = node_state["operating_state"] obs["NICS"] = { @@ -324,6 +326,7 @@ class HostObservation(AbstractObservation, identifier="HOST"): def space(self) -> spaces.Space: shape = { "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}), "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), "operating_status": spaces.Discrete(5), "NICS": spaces.Dict( @@ -393,15 +396,17 @@ class PortObservation(AbstractObservation, identifier="PORT"): class ACLObservation(AbstractObservation, identifier="ACL"): class ConfigSchema(AbstractObservation.ConfigSchema): - ip_list: List[IPv4Address] - port_list: List[int] - protocol_list: List[str] - num_rules: int + ip_list: Optional[List[IPv4Address]] = None + wildcard_list: Optional[List[str]] = None + port_list: Optional[List[int]] = None + protocol_list: Optional[List[str]] = None + num_rules: Optional[int] = None - def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None: + def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], wildcard_list: List[str], port_list: List[int],protocol_list: List[str])->None: self.where = where self.num_rules: int = num_rules self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)} + self.wildcard_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(wildcard_list)} self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)} self.default_observation: Dict = { @@ -409,10 +414,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"): + 1: { "position": i, "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, + "source_ip_id": 0, + "source_wildcard_id": 0, + "source_port_id": 0, + "dest_ip_id": 0, + "dest_wildcard_id": 0, + "dest_port_id": 0, "protocol": 0, } for i in range(self.num_rules) @@ -431,30 +438,38 @@ class ACLObservation(AbstractObservation, identifier="ACL"): obs[i] = { "position": i - 1, "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, + "source_ip_id": 0, + "source_wildcard_id": 0, + "source_port_id": 0, + "dest_ip_id": 0, + "dest_wildcard_id": 0, + "dest_port_id": 0, "protocol": 0, } else: src_ip = rule_state["src_ip_address"] - src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + src_node_id = self.ip_to_id.get(src_ip, 1) dst_ip = rule_state["dst_ip_address"] - dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] - src_port = rule_state["src_port"] - src_port_id = 1 if src_port is None else self.port_to_id[src_port] - dst_port = rule_state["dst_port"] - dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + dst_node_ip = self.ip_to_id.get(dst_ip, 1) + src_wildcard = rule_state["source_wildcard_id"] + src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1) + dst_wildcard = rule_state["dest_wildcard_id"] + dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1) + src_port = rule_state["source_port_id"] + src_port_id = self.port_to_id.get(src_port, 1) + dst_port = rule_state["dest_port_id"] + dst_port_id = self.port_to_id.get(dst_port, 1) protocol = rule_state["protocol"] - protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { "position": i - 1, "permission": rule_state["action"], - "source_node_id": src_node_id, - "source_port": src_port_id, - "dest_node_id": dst_node_ip, - "dest_port": dst_port_id, + "source_ip_id": src_node_id, + "source_wildcard_id": src_wildcard_id, + "source_port_id": src_port_id, + "dest_ip_id": dst_node_ip, + "dest_wildcard_id": dst_wildcard_id, + "dest_port_id": dst_port_id, "protocol": protocol_id, } i += 1 @@ -462,7 +477,6 @@ class ACLObservation(AbstractObservation, identifier="ACL"): @property def space(self) -> spaces.Space: - raise NotImplementedError("TODO: need to add wildcard id.") return spaces.Dict( { i @@ -471,10 +485,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"): "position": spaces.Discrete(self.num_rules), "permission": spaces.Discrete(3), # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), + "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), + "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), + "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2), + "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), } ) @@ -489,20 +505,22 @@ class ACLObservation(AbstractObservation, identifier="ACL"): where = parent_where+["acl", "acl"], num_rules = config.num_rules, ip_list = config.ip_list, - ports = config.port_list, - protocols = config.protocol_list + wildcard_list = config.wildcard_list, + port_list = config.port_list, + protocol_list = config.protocol_list ) class RouterObservation(AbstractObservation, identifier="ROUTER"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str - ports: List[PortObservation.ConfigSchema] - num_ports: int - acl: ACLObservation.ConfigSchema - ip_list: List[str] - port_list: List[int] - protocol_list: List[str] - num_rules: int + ports: Optional[List[PortObservation.ConfigSchema]] = None + num_ports: Optional[int] = None + acl: Optional[ACLObservation.ConfigSchema] = None + ip_list: Optional[List[str]] = None + wildcard_list: Optional[List[str]] = None + port_list: Optional[List[int]] = None + protocol_list: Optional[List[str]] = None + num_rules: Optional[int] = None def __init__(self, where: WhereType, @@ -540,7 +558,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): @property def space(self) -> spaces.Space: return spaces.Dict({ - "PORTS": {i+1:p.space for i,p in self.ports}, + "PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}), "ACL": self.acl.space }) @@ -548,15 +566,22 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation: where = parent_where + ["nodes", config.hostname] + if config.acl is None: + config.acl = ACLObservation.ConfigSchema() if config.acl.num_rules is None: config.acl.num_rules = config.num_rules if config.acl.ip_list is None: config.acl.ip_list = config.ip_list + if config.acl.wildcard_list is None: + config.acl.wildcard_list = config.wildcard_list if config.acl.port_list is None: config.acl.port_list = config.port_list if config.acl.protocol_list is None: config.acl.protocol_list = config.protocol_list + if config.ports is None: + config.ports = [PortObservation.ConfigSchema(port_id=i+1) for i in range(config.num_ports)] + ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports] acl = ACLObservation.from_config(config=config.acl, parent_where=where) return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl) @@ -564,30 +589,32 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"): class FirewallObservation(AbstractObservation, identifier="FIREWALL"): class ConfigSchema(AbstractObservation.ConfigSchema): hostname: str - ip_list: List[str] - port_list: List[int] - protocol_list: List[str] - num_rules: int + ip_list: Optional[List[str]] = None + wildcard_list: Optional[List[str]] = None + port_list: Optional[List[int]] = None + protocol_list: Optional[List[str]] = None + num_rules: Optional[int] = None def __init__(self, where: WhereType, ip_list: List[str], + wildcard_list: List[str], port_list: List[int], protocol_list: List[str], num_rules: int, )->None: self.where: WhereType = where - self.ports: List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ] + self.ports: List[PortObservation] = [PortObservation(where=self.where+["port", port_num]) for port_num in (1,2,3) ] #TODO: check what the port nums are for firewall. - self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) - self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list) + self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) + self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list) self.default_observation = { @@ -646,7 +673,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"): @classmethod def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation: where = parent_where+["nodes", config.hostname] - return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules) + return cls(where=where, ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules) class NodesObservation(AbstractObservation, identifier="NODES"): class ConfigSchema(AbstractObservation.ConfigSchema): @@ -663,7 +690,9 @@ class NodesObservation(AbstractObservation, identifier="NODES"): include_nmne: bool include_num_access: bool + num_ports: int ip_list: List[str] + wildcard_list: List[str] port_list: List[int] protocol_list: List[str] num_rules: int @@ -710,13 +739,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"): if host_config.num_services is None: host_config.num_services = config.num_services if host_config.num_applications is None: - host_config.num_application = config.num_applications + host_config.num_applications = config.num_applications if host_config.num_folders is None: - host_config.num_folder = config.num_folders + host_config.num_folders = config.num_folders if host_config.num_files is None: - host_config.num_file = config.num_files + host_config.num_files = config.num_files if host_config.num_nics is None: - host_config.num_nic = config.num_nics + host_config.num_nics = config.num_nics if host_config.include_nmne is None: host_config.include_nmne = config.include_nmne if host_config.include_num_access is None: @@ -727,26 +756,24 @@ class NodesObservation(AbstractObservation, identifier="NODES"): router_config.num_ports = config.num_ports if router_config.ip_list is None: router_config.ip_list = config.ip_list - + if router_config.wildcard_list is None: + router_config.wildcard_list = config.wildcard_list if router_config.port_list is None: router_config.port_list = config.port_list - if router_config.protocol_list is None: router_config.protocol_list = config.protocol_list - if router_config.num_rules is None: router_config.num_rules = config.num_rules for firewall_config in config.firewalls: if firewall_config.ip_list is None: firewall_config.ip_list = config.ip_list - + if firewall_config.wildcard_list is None: + firewall_config.wildcard_list = config.wildcard_list if firewall_config.port_list is None: firewall_config.port_list = config.port_list - if firewall_config.protocol_list is None: firewall_config.protocol_list = config.protocol_list - if firewall_config.num_rules is None: firewall_config.num_rules = config.num_rules @@ -754,4 +781,4 @@ class NodesObservation(AbstractObservation, identifier="NODES"): routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers] firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls] - cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls) + return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls) diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index d2b47c1a..69ab6a82 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -147,8 +147,10 @@ class ACLRule(SimComponent): state["action"] = self.action.value state["protocol"] = self.protocol.name if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None + state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None state["src_port"] = self.src_port.name if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None + state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None state["dst_port"] = self.dst_port.name if self.dst_port else None state["match_count"] = self.match_count return state