Fixed observations

This commit is contained in:
Marek Wolan
2024-03-29 11:55:22 +00:00
parent f88b4c0f97
commit d8a66104f5
2 changed files with 102 additions and 73 deletions

View File

@@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
class FileObservation(AbstractObservation, identifier="FILE"):
class ConfigSchema(AbstractObservation.ConfigSchema):
file_name: str
include_num_access : bool = False
include_num_access: Optional[bool] = None
def __init__(self, where: WhereType, include_num_access: bool)->None:
self.where: WhereType = where
@@ -118,8 +118,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
class ConfigSchema(AbstractObservation.ConfigSchema):
folder_name: str
files: List[FileObservation.ConfigSchema] = []
num_files : int = 0
include_num_access : bool = False
num_files : Optional[int] = None
include_num_access : Optional[bool] = None
def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None:
self.where: WhereType = where
@@ -179,7 +179,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
class ConfigSchema(AbstractObservation.ConfigSchema):
nic_num: int
include_nmne: bool = False
include_nmne: Optional[bool] = None
def __init__(self, where: WhereType, include_nmne: bool)->None:
@@ -233,13 +233,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
applications: List[ApplicationObservation.ConfigSchema] = []
folders: List[FolderObservation.ConfigSchema] = []
network_interfaces: List[NICObservation.ConfigSchema] = []
num_services: int
num_applications: int
num_folders: int
num_files: int
num_nics: int
include_nmne: bool
include_num_access: bool
num_services: Optional[int] = None
num_applications: Optional[int] = None
num_folders: Optional[int] = None
num_files: Optional[int] = None
num_nics: Optional[int] = None
include_nmne: Optional[bool] = None
include_num_access: Optional[bool] = None
def __init__(self,
where: WhereType,
@@ -296,6 +296,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
self.default_observation: ObsType = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
@@ -311,6 +312,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
@@ -324,6 +326,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
def space(self) -> spaces.Space:
shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
@@ -393,15 +396,17 @@ class PortObservation(AbstractObservation, identifier="PORT"):
class ACLObservation(AbstractObservation, identifier="ACL"):
class ConfigSchema(AbstractObservation.ConfigSchema):
ip_list: List[IPv4Address]
port_list: List[int]
protocol_list: List[str]
num_rules: int
ip_list: Optional[List[IPv4Address]] = None
wildcard_list: Optional[List[str]] = None
port_list: Optional[List[int]] = None
protocol_list: Optional[List[str]] = None
num_rules: Optional[int] = None
def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None:
def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], wildcard_list: List[str], port_list: List[int],protocol_list: List[str])->None:
self.where = where
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)}
self.default_observation: Dict = {
@@ -409,10 +414,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol": 0,
}
for i in range(self.num_rules)
@@ -431,30 +438,38 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
src_node_id = self.ip_to_id.get(src_ip, 1)
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
dst_node_ip = self.ip_to_id.get(dst_ip, 1)
src_wildcard = rule_state["source_wildcard_id"]
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
dst_wildcard = rule_state["dest_wildcard_id"]
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
src_port = rule_state["source_port_id"]
src_port_id = self.port_to_id.get(src_port, 1)
dst_port = rule_state["dest_port_id"]
dst_port_id = self.port_to_id.get(dst_port, 1)
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
protocol_id = self.protocol_to_id.get(protocol, 1)
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"source_ip_id": src_node_id,
"source_wildcard_id": src_wildcard_id,
"source_port_id": src_port_id,
"dest_ip_id": dst_node_ip,
"dest_wildcard_id": dst_wildcard_id,
"dest_port_id": dst_port_id,
"protocol": protocol_id,
}
i += 1
@@ -462,7 +477,6 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
@property
def space(self) -> spaces.Space:
raise NotImplementedError("TODO: need to add wildcard id.")
return spaces.Dict(
{
i
@@ -471,10 +485,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2),
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2),
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
@@ -489,20 +505,22 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
where = parent_where+["acl", "acl"],
num_rules = config.num_rules,
ip_list = config.ip_list,
ports = config.port_list,
protocols = config.protocol_list
wildcard_list = config.wildcard_list,
port_list = config.port_list,
protocol_list = config.protocol_list
)
class RouterObservation(AbstractObservation, identifier="ROUTER"):
class ConfigSchema(AbstractObservation.ConfigSchema):
hostname: str
ports: List[PortObservation.ConfigSchema]
num_ports: int
acl: ACLObservation.ConfigSchema
ip_list: List[str]
port_list: List[int]
protocol_list: List[str]
num_rules: int
ports: Optional[List[PortObservation.ConfigSchema]] = None
num_ports: Optional[int] = None
acl: Optional[ACLObservation.ConfigSchema] = None
ip_list: Optional[List[str]] = None
wildcard_list: Optional[List[str]] = None
port_list: Optional[List[int]] = None
protocol_list: Optional[List[str]] = None
num_rules: Optional[int] = None
def __init__(self,
where: WhereType,
@@ -540,7 +558,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
@property
def space(self) -> spaces.Space:
return spaces.Dict({
"PORTS": {i+1:p.space for i,p in self.ports},
"PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}),
"ACL": self.acl.space
})
@@ -548,15 +566,22 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation:
where = parent_where + ["nodes", config.hostname]
if config.acl is None:
config.acl = ACLObservation.ConfigSchema()
if config.acl.num_rules is None:
config.acl.num_rules = config.num_rules
if config.acl.ip_list is None:
config.acl.ip_list = config.ip_list
if config.acl.wildcard_list is None:
config.acl.wildcard_list = config.wildcard_list
if config.acl.port_list is None:
config.acl.port_list = config.port_list
if config.acl.protocol_list is None:
config.acl.protocol_list = config.protocol_list
if config.ports is None:
config.ports = [PortObservation.ConfigSchema(port_id=i+1) for i in range(config.num_ports)]
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
@@ -564,30 +589,32 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
class ConfigSchema(AbstractObservation.ConfigSchema):
hostname: str
ip_list: List[str]
port_list: List[int]
protocol_list: List[str]
num_rules: int
ip_list: Optional[List[str]] = None
wildcard_list: Optional[List[str]] = None
port_list: Optional[List[int]] = None
protocol_list: Optional[List[str]] = None
num_rules: Optional[int] = None
def __init__(self,
where: WhereType,
ip_list: List[str],
wildcard_list: List[str],
port_list: List[int],
protocol_list: List[str],
num_rules: int,
)->None:
self.where: WhereType = where
self.ports: List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ]
self.ports: List[PortObservation] = [PortObservation(where=self.where+["port", port_num]) for port_num in (1,2,3) ]
#TODO: check what the port nums are for firewall.
self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
self.default_observation = {
@@ -646,7 +673,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
@classmethod
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation:
where = parent_where+["nodes", config.hostname]
return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules)
return cls(where=where, ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules)
class NodesObservation(AbstractObservation, identifier="NODES"):
class ConfigSchema(AbstractObservation.ConfigSchema):
@@ -663,7 +690,9 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
include_nmne: bool
include_num_access: bool
num_ports: int
ip_list: List[str]
wildcard_list: List[str]
port_list: List[int]
protocol_list: List[str]
num_rules: int
@@ -710,13 +739,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
if host_config.num_services is None:
host_config.num_services = config.num_services
if host_config.num_applications is None:
host_config.num_application = config.num_applications
host_config.num_applications = config.num_applications
if host_config.num_folders is None:
host_config.num_folder = config.num_folders
host_config.num_folders = config.num_folders
if host_config.num_files is None:
host_config.num_file = config.num_files
host_config.num_files = config.num_files
if host_config.num_nics is None:
host_config.num_nic = config.num_nics
host_config.num_nics = config.num_nics
if host_config.include_nmne is None:
host_config.include_nmne = config.include_nmne
if host_config.include_num_access is None:
@@ -727,26 +756,24 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
router_config.num_ports = config.num_ports
if router_config.ip_list is None:
router_config.ip_list = config.ip_list
if router_config.wildcard_list is None:
router_config.wildcard_list = config.wildcard_list
if router_config.port_list is None:
router_config.port_list = config.port_list
if router_config.protocol_list is None:
router_config.protocol_list = config.protocol_list
if router_config.num_rules is None:
router_config.num_rules = config.num_rules
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
firewall_config.ip_list = config.ip_list
if firewall_config.wildcard_list is None:
firewall_config.wildcard_list = config.wildcard_list
if firewall_config.port_list is None:
firewall_config.port_list = config.port_list
if firewall_config.protocol_list is None:
firewall_config.protocol_list = config.protocol_list
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
@@ -754,4 +781,4 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)

View File

@@ -147,8 +147,10 @@ class ACLRule(SimComponent):
state["action"] = self.action.value
state["protocol"] = self.protocol.name if self.protocol else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None
state["src_port"] = self.src_port.name if self.src_port else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None
state["dst_port"] = self.dst_port.name if self.dst_port else None
state["match_count"] = self.match_count
return state