#2417 more observations

This commit is contained in:
Marek Wolan
2024-03-28 17:40:27 +00:00
parent 0d0b5bc7d9
commit f88b4c0f97

View File

@@ -1,4 +1,6 @@
# TODO: make sure when config options are being passed down from higher-level observations to lower-level, but the lower-level also defines that option, don't overwrite.
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
from gymnasium import spaces
@@ -163,7 +165,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
}
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation:
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FolderObservation:
where = parent_where + ["folders", config.folder_name]
#pass down shared/common config items
@@ -220,7 +222,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> NICObservation:
return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne)
@@ -333,7 +335,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation:
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> HostObservation:
if parent_where is None:
where = ["network", "nodes", config.hostname]
else:
@@ -369,78 +371,282 @@ class HostObservation(AbstractObservation, identifier="HOST"):
class PortObservation(AbstractObservation, identifier="PORT"):
class ConfigSchema(AbstractObservation.ConfigSchema):
pass
port_id : int
def __init__(self, where: WhereType)->None:
pass
self.where = where
self.default_observation: ObsType = {"operating_status" : 0}
def observe(self, state: Dict) -> Any:
pass
port_state = access_from_nested_dict(state, self.where)
if port_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"operating_status": 1 if port_state["enabled"] else 2 }
@property
def space(self) -> spaces.Space:
pass
return spaces.Dict({"operating_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> PortObservation:
return cls(where = parent_where + ["NICs", config.port_id])
class ACLObservation(AbstractObservation, identifier="ACL"):
class ConfigSchema(AbstractObservation.ConfigSchema):
pass
ip_list: List[IPv4Address]
port_list: List[int]
protocol_list: List[str]
num_rules: int
def __init__(self, where: WhereType)->None:
pass
def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None:
self.where = where
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)}
self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)}
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Any:
pass
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
pass
raise NotImplementedError("TODO: need to add wildcard id.")
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ACLObservation:
return cls(
where = parent_where+["acl", "acl"],
num_rules = config.num_rules,
ip_list = config.ip_list,
ports = config.port_list,
protocols = config.protocol_list
)
class RouterObservation(AbstractObservation, identifier="ROUTER"):
class ConfigSchema(AbstractObservation.ConfigSchema):
hostname: str
ports: List[PortObservation.ConfigSchema]
num_ports: int
acl: ACLObservation.ConfigSchema
ip_list: List[str]
port_list: List[int]
protocol_list: List[str]
num_rules: int
def __init__(self,
where: WhereType,
ports:List[PortObservation],
num_ports: int,
acl: ACLObservation,
)->None:
self.where: WhereType = where
self.ports: List[PortObservation] = ports
self.acl: ACLObservation = acl
self.num_ports:int = num_ports
def __init__(self, where: WhereType)->None:
pass
while len(self.ports) < num_ports:
self.ports.append(PortObservation(where=None))
while len(self.ports) > num_ports:
self.ports.pop()
msg = f"Too many ports in router observation. Truncating."
_LOGGER.warning(msg)
self.default_observation = {
"PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)},
"ACL": self.acl.default_observation
}
def observe(self, state: Dict) -> Any:
pass
router_state = access_from_nested_dict(state, self.where)
if router_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["PORTS"] = {i+1:p.observe(state) for i,p in enumerate(self.ports)}
obs["ACL"] = self.acl.observe(state)
return obs
@property
def space(self) -> spaces.Space:
pass
return spaces.Dict({
"PORTS": {i+1:p.space for i,p in self.ports},
"ACL": self.acl.space
})
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation:
where = parent_where + ["nodes", config.hostname]
if config.acl.num_rules is None:
config.acl.num_rules = config.num_rules
if config.acl.ip_list is None:
config.acl.ip_list = config.ip_list
if config.acl.port_list is None:
config.acl.port_list = config.port_list
if config.acl.protocol_list is None:
config.acl.protocol_list = config.protocol_list
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
class ConfigSchema(AbstractObservation.ConfigSchema):
hostname: str
ports: List[PortObservation.ConfigSchema] = []
ip_list: List[str]
port_list: List[int]
protocol_list: List[str]
num_rules: int
def __init__(self, where: WhereType)->None:
pass
def __init__(self,
where: WhereType,
ip_list: List[str],
port_list: List[int],
protocol_list: List[str],
num_rules: int,
)->None:
self.where: WhereType = where
self.ports: List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ]
#TODO: check what the port nums are for firewall.
self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.default_observation = {
"PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)},
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.default_observation,
"OUTBOUND": self.internal_outbound_acl.default_observation,
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.default_observation,
"OUTBOUND": self.dmz_outbound_acl.default_observation,
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.default_observation,
"OUTBOUND": self.external_outbound_acl.default_observation,
},
}
def observe(self, state: Dict) -> Any:
pass
obs = {
"PORTS": {i+1:p.observe(state) for i,p in enumerate(self.ports)},
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
}
return obs
@property
def space(self) -> spaces.Space:
pass
space =spaces.Dict({
"PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}),
"INTERNAL": spaces.Dict({
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}),
"DMZ": spaces.Dict({
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}),
"EXTERNAL": spaces.Dict({
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}),
})
return space
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation:
where = parent_where+["nodes", config.hostname]
return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules)
class NodesObservation(AbstractObservation, identifier="NODES"):
class ConfigSchema(AbstractObservation.ConfigSchema):
@@ -448,205 +654,104 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
hosts: List[HostObservation.ConfigSchema] = []
routers: List[RouterObservation.ConfigSchema] = []
firewalls: List[FirewallObservation.ConfigSchema] = []
num_services: int = 1
num_services: int
num_applications: int
num_folders: int
num_files: int
num_nics: int
include_nmne: bool
include_num_access: bool
ip_list: List[str]
port_list: List[int]
protocol_list: List[str]
num_rules: int
def __init__(self, where: WhereType)->None:
pass
def __init__(self, where: WhereType, hosts:List[HostObservation], routers:List[RouterObservation], firewalls:List[FirewallObservation])->None:
self.where :WhereType = where
self.hosts: List[HostObservation] = hosts
self.routers: List[RouterObservation] = routers
self.firewalls: List[FirewallObservation] = firewalls
self.default_observation = {
**{f"HOST{i}":host.default_observation for i,host in enumerate(self.hosts)},
**{f"ROUTER{i}":router.default_observation for i,router in enumerate(self.routers)},
**{f"FIREWALL{i}":firewall.default_observation for i,firewall in enumerate(self.firewalls)},
}
def observe(self, state: Dict) -> Any:
pass
@property
def space(self) -> spaces.Space:
pass
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
pass
############################ OLD
class NodeObservation(AbstractObservation, identifier= "OLD"):
"""Observation of a node in the network. Includes services, folders and NICs."""
def __init__(
self,
where: Optional[Tuple[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
network_interfaces: List[NicObservation] = [],
logon_status: bool = False,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> None:
"""
Configurable observation for a node in the simulation.
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service name, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
:type max_services: int, optional
:param folders: Mapping between position in observation space and folder name, defaults to {}
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
:type network_interfaces: Dict[int,str], optional
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
:type max_nics: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.services: List[ServiceObservation] = services
while len(self.services) < num_services_per_node:
# add empty service observation without `where` parameter so it always returns default (blank) observation
self.services.append(ServiceObservation())
while len(self.services) > num_services_per_node:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
# truncate service list
self.folders: List[FolderObservation] = folders
# add empty folder observation without `where` parameter that will always return default (blank) observations
while len(self.folders) < num_folders_per_node:
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warning(msg)
self.network_interfaces: List[NicObservation] = network_interfaces
while len(self.network_interfaces) < num_nics_per_node:
self.network_interfaces.append(NicObservation())
while len(self.network_interfaces) > num_nics_per_node:
truncated_nic = self.network_interfaces.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warning(msg)
self.logon_status: bool = logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
obs = {
**{f"HOST{i}":host.observe(state) for i,host in enumerate(self.hosts)},
**{f"ROUTER{i}":router.observe(state) for i,router in enumerate(self.routers)},
**{f"FIREWALL{i}":firewall.observe(state) for i,firewall in enumerate(self.firewalls)},
}
if self.logon_status:
self.default_observation["logon_status"] = 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
if self.logon_status:
obs["logon_status"] = 0
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
space_shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
}
if self.logon_status:
space_shape["logon_status"] = spaces.Discrete(3)
return spaces.Dict(space_shape)
space = spaces.Dict({
**{f"HOST{i}":host.space for i,host in enumerate(self.hosts)},
**{f"ROUTER{i}":router.space for i,router in enumerate(self.routers)},
**{f"FIREWALL{i}":firewall.space for i,firewall in enumerate(self.firewalls)},
})
return space
@classmethod
def from_config(
cls,
config: Dict,
game: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> "NodeObservation":
"""Create node observation from a config. Also creates child service, folder and NIC observations.
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
observation size) , defaults to 2
:type num_services_per_node: int, optional
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
observation size) , defaults to 2
:type num_folders_per_node: int, optional
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed node observation
:rtype: NodeObservation
"""
node_hostname = config["node_hostname"]
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
if parent_where is None:
where = ["network", "nodes", node_hostname]
where = ["network", "nodes"]
else:
where = parent_where + ["nodes", node_hostname]
where = parent_where + ["nodes"]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
services=services,
folders=folders,
network_interfaces=network_interfaces,
logon_status=logon_status,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
for host_config in config.hosts:
if host_config.num_services is None:
host_config.num_services = config.num_services
if host_config.num_applications is None:
host_config.num_application = config.num_applications
if host_config.num_folders is None:
host_config.num_folder = config.num_folders
if host_config.num_files is None:
host_config.num_file = config.num_files
if host_config.num_nics is None:
host_config.num_nic = config.num_nics
if host_config.include_nmne is None:
host_config.include_nmne = config.include_nmne
if host_config.include_num_access is None:
host_config.include_num_access = config.include_num_access
for router_config in config.routers:
if router_config.num_ports is None:
router_config.num_ports = config.num_ports
if router_config.ip_list is None:
router_config.ip_list = config.ip_list
if router_config.port_list is None:
router_config.port_list = config.port_list
if router_config.protocol_list is None:
router_config.protocol_list = config.protocol_list
if router_config.num_rules is None:
router_config.num_rules = config.num_rules
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
firewall_config.ip_list = config.ip_list
if firewall_config.port_list is None:
firewall_config.port_list = config.port_list
if firewall_config.protocol_list is None:
firewall_config.protocol_list = config.protocol_list
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)