#2417 more observations
This commit is contained in:
@@ -1,4 +1,6 @@
|
||||
# TODO: make sure when config options are being passed down from higher-level observations to lower-level, but the lower-level also defines that option, don't overwrite.
|
||||
from __future__ import annotations
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, TYPE_CHECKING, Union
|
||||
|
||||
from gymnasium import spaces
|
||||
@@ -163,7 +165,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
}
|
||||
)
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FileObservation:
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FolderObservation:
|
||||
where = parent_where + ["folders", config.folder_name]
|
||||
|
||||
#pass down shared/common config items
|
||||
@@ -220,7 +222,7 @@ class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> NICObservation:
|
||||
return cls(where = parent_where+["NICs", config.nic_num], include_nmne=config.include_nmne)
|
||||
|
||||
|
||||
@@ -333,7 +335,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> ServiceObservation:
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = None ) -> HostObservation:
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", config.hostname]
|
||||
else:
|
||||
@@ -369,78 +371,282 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
|
||||
class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
pass
|
||||
port_id : int
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"operating_status" : 0}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
port_state = access_from_nested_dict(state, self.where)
|
||||
if port_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"operating_status": 1 if port_state["enabled"] else 2 }
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> PortObservation:
|
||||
return cls(where = parent_where + ["NICs", config.port_id])
|
||||
|
||||
class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
pass
|
||||
ip_list: List[IPv4Address]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None:
|
||||
self.where = where
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)}
|
||||
self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
raise NotImplementedError("TODO: need to add wildcard id.")
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ACLObservation:
|
||||
return cls(
|
||||
where = parent_where+["acl", "acl"],
|
||||
num_rules = config.num_rules,
|
||||
ip_list = config.ip_list,
|
||||
ports = config.port_list,
|
||||
protocols = config.protocol_list
|
||||
)
|
||||
|
||||
class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
hostname: str
|
||||
ports: List[PortObservation.ConfigSchema]
|
||||
num_ports: int
|
||||
acl: ACLObservation.ConfigSchema
|
||||
ip_list: List[str]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
|
||||
def __init__(self,
|
||||
where: WhereType,
|
||||
ports:List[PortObservation],
|
||||
num_ports: int,
|
||||
acl: ACLObservation,
|
||||
)->None:
|
||||
self.where: WhereType = where
|
||||
self.ports: List[PortObservation] = ports
|
||||
self.acl: ACLObservation = acl
|
||||
self.num_ports:int = num_ports
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
while len(self.ports) < num_ports:
|
||||
self.ports.append(PortObservation(where=None))
|
||||
while len(self.ports) > num_ports:
|
||||
self.ports.pop()
|
||||
msg = f"Too many ports in router observation. Truncating."
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)},
|
||||
"ACL": self.acl.default_observation
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
router_state = access_from_nested_dict(state, self.where)
|
||||
if router_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["PORTS"] = {i+1:p.observe(state) for i,p in enumerate(self.ports)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
return spaces.Dict({
|
||||
"PORTS": {i+1:p.space for i,p in self.ports},
|
||||
"ACL": self.acl.space
|
||||
})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation:
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
|
||||
if config.acl.num_rules is None:
|
||||
config.acl.num_rules = config.num_rules
|
||||
if config.acl.ip_list is None:
|
||||
config.acl.ip_list = config.ip_list
|
||||
if config.acl.port_list is None:
|
||||
config.acl.port_list = config.port_list
|
||||
if config.acl.protocol_list is None:
|
||||
config.acl.protocol_list = config.protocol_list
|
||||
|
||||
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
|
||||
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
hostname: str
|
||||
ports: List[PortObservation.ConfigSchema] = []
|
||||
ip_list: List[str]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
|
||||
def __init__(self,
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
)->None:
|
||||
self.where: WhereType = where
|
||||
|
||||
self.ports: List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ]
|
||||
#TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i+1:p.default_observation for i,p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
obs = {
|
||||
"PORTS": {i+1:p.observe(state) for i,p in enumerate(self.ports)},
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
space =spaces.Dict({
|
||||
"PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}),
|
||||
"INTERNAL": spaces.Dict({
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}),
|
||||
"DMZ": spaces.Dict({
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}),
|
||||
"EXTERNAL": spaces.Dict({
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}),
|
||||
})
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation:
|
||||
where = parent_where+["nodes", config.hostname]
|
||||
return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules)
|
||||
|
||||
class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
@@ -448,205 +654,104 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
hosts: List[HostObservation.ConfigSchema] = []
|
||||
routers: List[RouterObservation.ConfigSchema] = []
|
||||
firewalls: List[FirewallObservation.ConfigSchema] = []
|
||||
num_services: int = 1
|
||||
|
||||
num_services: int
|
||||
num_applications: int
|
||||
num_folders: int
|
||||
num_files: int
|
||||
num_nics: int
|
||||
include_nmne: bool
|
||||
include_num_access: bool
|
||||
|
||||
ip_list: List[str]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
|
||||
|
||||
def __init__(self, where: WhereType)->None:
|
||||
pass
|
||||
def __init__(self, where: WhereType, hosts:List[HostObservation], routers:List[RouterObservation], firewalls:List[FirewallObservation])->None:
|
||||
self.where :WhereType = where
|
||||
|
||||
self.hosts: List[HostObservation] = hosts
|
||||
self.routers: List[RouterObservation] = routers
|
||||
self.firewalls: List[FirewallObservation] = firewalls
|
||||
|
||||
self.default_observation = {
|
||||
**{f"HOST{i}":host.default_observation for i,host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}":router.default_observation for i,router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}":firewall.default_observation for i,firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
pass
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
pass
|
||||
|
||||
############################ OLD
|
||||
|
||||
class NodeObservation(AbstractObservation, identifier= "OLD"):
|
||||
"""Observation of a node in the network. Includes services, folders and NICs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
network_interfaces: List[NicObservation] = [],
|
||||
logon_status: bool = False,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
|
||||
:param where: Where in the simulation state dictionary for find relevant information for this observation.
|
||||
A typical location for a node looks like this:
|
||||
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
|
||||
:type where: List[str], optional
|
||||
:param services: Mapping between position in observation space and service name, defaults to {}
|
||||
:type services: Dict[int,str], optional
|
||||
:param max_services: Max number of services that can be presented in observation space for this node
|
||||
, defaults to 2
|
||||
:type max_services: int, optional
|
||||
:param folders: Mapping between position in observation space and folder name, defaults to {}
|
||||
:type folders: Dict[int,str], optional
|
||||
:param max_folders: Max number of folders in this node's obs space, defaults to 2
|
||||
:type max_folders: int, optional
|
||||
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
|
||||
:type network_interfaces: Dict[int,str], optional
|
||||
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services_per_node:
|
||||
# add empty service observation without `where` parameter so it always returns default (blank) observation
|
||||
self.services.append(ServiceObservation())
|
||||
while len(self.services) > num_services_per_node:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
# truncate service list
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
# add empty folder observation without `where` parameter that will always return default (blank) observations
|
||||
while len(self.folders) < num_folders_per_node:
|
||||
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
|
||||
while len(self.folders) > num_folders_per_node:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.network_interfaces: List[NicObservation] = network_interfaces
|
||||
while len(self.network_interfaces) < num_nics_per_node:
|
||||
self.network_interfaces.append(NicObservation())
|
||||
while len(self.network_interfaces) > num_nics_per_node:
|
||||
truncated_nic = self.network_interfaces.pop()
|
||||
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
obs = {
|
||||
**{f"HOST{i}":host.observe(state) for i,host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}":router.observe(state) for i,router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}":firewall.observe(state) for i,firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
if self.logon_status:
|
||||
self.default_observation["logon_status"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
}
|
||||
|
||||
if self.logon_status:
|
||||
obs["logon_status"] = 0
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
space_shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
}
|
||||
if self.logon_status:
|
||||
space_shape["logon_status"] = spaces.Discrete(3)
|
||||
|
||||
return spaces.Dict(space_shape)
|
||||
space = spaces.Dict({
|
||||
**{f"HOST{i}":host.space for i,host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}":router.space for i,router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}":firewall.space for i,firewall in enumerate(self.firewalls)},
|
||||
})
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
game: "PrimaiteGame",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> "NodeObservation":
|
||||
"""Create node observation from a config. Also creates child service, folder and NIC observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this node observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
|
||||
network. A typical location for it would be: ['network',]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_services_per_node: int, optional
|
||||
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_folders_per_node: int, optional
|
||||
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
"""
|
||||
node_hostname = config["node_hostname"]
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> ServiceObservation:
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_hostname]
|
||||
where = ["network", "nodes"]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_hostname]
|
||||
where = parent_where + ["nodes"]
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
|
||||
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
|
||||
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
folders=folders,
|
||||
network_interfaces=network_interfaces,
|
||||
logon_status=logon_status,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
for host_config in config.hosts:
|
||||
if host_config.num_services is None:
|
||||
host_config.num_services = config.num_services
|
||||
if host_config.num_applications is None:
|
||||
host_config.num_application = config.num_applications
|
||||
if host_config.num_folders is None:
|
||||
host_config.num_folder = config.num_folders
|
||||
if host_config.num_files is None:
|
||||
host_config.num_file = config.num_files
|
||||
if host_config.num_nics is None:
|
||||
host_config.num_nic = config.num_nics
|
||||
if host_config.include_nmne is None:
|
||||
host_config.include_nmne = config.include_nmne
|
||||
if host_config.include_num_access is None:
|
||||
host_config.include_num_access = config.include_num_access
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
router_config.num_ports = config.num_ports
|
||||
if router_config.ip_list is None:
|
||||
router_config.ip_list = config.ip_list
|
||||
|
||||
if router_config.port_list is None:
|
||||
router_config.port_list = config.port_list
|
||||
|
||||
if router_config.protocol_list is None:
|
||||
router_config.protocol_list = config.protocol_list
|
||||
|
||||
if router_config.num_rules is None:
|
||||
router_config.num_rules = config.num_rules
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
firewall_config.ip_list = config.ip_list
|
||||
|
||||
if firewall_config.port_list is None:
|
||||
firewall_config.port_list = config.port_list
|
||||
|
||||
if firewall_config.protocol_list is None:
|
||||
firewall_config.protocol_list = config.protocol_list
|
||||
|
||||
if firewall_config.num_rules is None:
|
||||
firewall_config.num_rules = config.num_rules
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
|
||||
|
||||
cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|
||||
|
||||
Reference in New Issue
Block a user