Fixed observations
This commit is contained in:
@@ -82,7 +82,7 @@ class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
file_name: str
|
||||
include_num_access : bool = False
|
||||
include_num_access: Optional[bool] = None
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool)->None:
|
||||
self.where: WhereType = where
|
||||
@@ -118,8 +118,8 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
folder_name: str
|
||||
files: List[FileObservation.ConfigSchema] = []
|
||||
num_files : int = 0
|
||||
include_num_access : bool = False
|
||||
num_files : Optional[int] = None
|
||||
include_num_access : Optional[bool] = None
|
||||
|
||||
def __init__(self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool)->None:
|
||||
self.where: WhereType = where
|
||||
@@ -179,7 +179,7 @@ class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
nic_num: int
|
||||
include_nmne: bool = False
|
||||
include_nmne: Optional[bool] = None
|
||||
|
||||
|
||||
def __init__(self, where: WhereType, include_nmne: bool)->None:
|
||||
@@ -233,13 +233,13 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
applications: List[ApplicationObservation.ConfigSchema] = []
|
||||
folders: List[FolderObservation.ConfigSchema] = []
|
||||
network_interfaces: List[NICObservation.ConfigSchema] = []
|
||||
num_services: int
|
||||
num_applications: int
|
||||
num_folders: int
|
||||
num_files: int
|
||||
num_nics: int
|
||||
include_nmne: bool
|
||||
include_num_access: bool
|
||||
num_services: Optional[int] = None
|
||||
num_applications: Optional[int] = None
|
||||
num_folders: Optional[int] = None
|
||||
num_files: Optional[int] = None
|
||||
num_nics: Optional[int] = None
|
||||
include_nmne: Optional[bool] = None
|
||||
include_num_access: Optional[bool] = None
|
||||
|
||||
def __init__(self,
|
||||
where: WhereType,
|
||||
@@ -296,6 +296,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
|
||||
self.default_observation: ObsType = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"APPLICATIONS": {i + 1: a.default_observation for i, a in enumerate(self.applications)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
@@ -311,6 +312,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
@@ -324,6 +326,7 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
def space(self) -> spaces.Space:
|
||||
shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"APPLICATIONS": spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
@@ -393,15 +396,17 @@ class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
|
||||
class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
ip_list: List[IPv4Address]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
ip_list: Optional[List[IPv4Address]] = None
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
port_list: Optional[List[int]] = None
|
||||
protocol_list: Optional[List[str]] = None
|
||||
num_rules: Optional[int] = None
|
||||
|
||||
def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], port_list: List[int],protocol_list: List[str])->None:
|
||||
def __init__(self, where: WhereType, num_rules: int, ip_list: List[IPv4Address], wildcard_list: List[str], port_list: List[int],protocol_list: List[str])->None:
|
||||
self.where = where
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {i+2:p for i,p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {i+2:p for i,p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
@@ -409,10 +414,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
@@ -431,30 +438,38 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
src_node_id = self.ip_to_id.get(src_ip, 1)
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
dst_node_ip = self.ip_to_id.get(dst_ip, 1)
|
||||
src_wildcard = rule_state["source_wildcard_id"]
|
||||
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
|
||||
dst_wildcard = rule_state["dest_wildcard_id"]
|
||||
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
|
||||
src_port = rule_state["source_port_id"]
|
||||
src_port_id = self.port_to_id.get(src_port, 1)
|
||||
dst_port = rule_state["dest_port_id"]
|
||||
dst_port_id = self.port_to_id.get(dst_port, 1)
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
"source_port_id": src_port_id,
|
||||
"dest_ip_id": dst_node_ip,
|
||||
"dest_wildcard_id": dst_wildcard_id,
|
||||
"dest_port_id": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
@@ -462,7 +477,6 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
raise NotImplementedError("TODO: need to add wildcard id.")
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
@@ -471,10 +485,12 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2),
|
||||
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id)+2),
|
||||
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
@@ -489,20 +505,22 @@ class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
where = parent_where+["acl", "acl"],
|
||||
num_rules = config.num_rules,
|
||||
ip_list = config.ip_list,
|
||||
ports = config.port_list,
|
||||
protocols = config.protocol_list
|
||||
wildcard_list = config.wildcard_list,
|
||||
port_list = config.port_list,
|
||||
protocol_list = config.protocol_list
|
||||
)
|
||||
|
||||
class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
hostname: str
|
||||
ports: List[PortObservation.ConfigSchema]
|
||||
num_ports: int
|
||||
acl: ACLObservation.ConfigSchema
|
||||
ip_list: List[str]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
ports: Optional[List[PortObservation.ConfigSchema]] = None
|
||||
num_ports: Optional[int] = None
|
||||
acl: Optional[ACLObservation.ConfigSchema] = None
|
||||
ip_list: Optional[List[str]] = None
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
port_list: Optional[List[int]] = None
|
||||
protocol_list: Optional[List[str]] = None
|
||||
num_rules: Optional[int] = None
|
||||
|
||||
def __init__(self,
|
||||
where: WhereType,
|
||||
@@ -540,7 +558,7 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
return spaces.Dict({
|
||||
"PORTS": {i+1:p.space for i,p in self.ports},
|
||||
"PORTS": spaces.Dict({i+1:p.space for i,p in enumerate(self.ports)}),
|
||||
"ACL": self.acl.space
|
||||
})
|
||||
|
||||
@@ -548,15 +566,22 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> RouterObservation:
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
|
||||
if config.acl is None:
|
||||
config.acl = ACLObservation.ConfigSchema()
|
||||
if config.acl.num_rules is None:
|
||||
config.acl.num_rules = config.num_rules
|
||||
if config.acl.ip_list is None:
|
||||
config.acl.ip_list = config.ip_list
|
||||
if config.acl.wildcard_list is None:
|
||||
config.acl.wildcard_list = config.wildcard_list
|
||||
if config.acl.port_list is None:
|
||||
config.acl.port_list = config.port_list
|
||||
if config.acl.protocol_list is None:
|
||||
config.acl.protocol_list = config.protocol_list
|
||||
|
||||
if config.ports is None:
|
||||
config.ports = [PortObservation.ConfigSchema(port_id=i+1) for i in range(config.num_ports)]
|
||||
|
||||
ports = [PortObservation.from_config(config=c, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
@@ -564,30 +589,32 @@ class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
hostname: str
|
||||
ip_list: List[str]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
ip_list: Optional[List[str]] = None
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
port_list: Optional[List[int]] = None
|
||||
protocol_list: Optional[List[str]] = None
|
||||
num_rules: Optional[int] = None
|
||||
|
||||
|
||||
def __init__(self,
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
)->None:
|
||||
self.where: WhereType = where
|
||||
|
||||
self.ports: List[PortObservation] = [PortObservation(where=[self.where+["port", port_num]]) for port_num in (1,2,3) ]
|
||||
self.ports: List[PortObservation] = [PortObservation(where=self.where+["port", port_num]) for port_num in (1,2,3) ]
|
||||
#TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.internal_inbound_acl = ACLObservation(where = self.where+["acl","internal","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.internal_outbound_acl = ACLObservation(where = self.where+["acl","internal","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.dmz_inbound_acl = ACLObservation(where = self.where+["acl","dmz","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.dmz_outbound_acl = ACLObservation(where = self.where+["acl","dmz","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.external_inbound_acl = ACLObservation(where = self.where+["acl","external","inbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
|
||||
self.external_outbound_acl = ACLObservation(where = self.where+["acl","external","outbound"], num_rules= num_rules, ip_list=ip_list, wildcard_list=wildcard_list, port_list=port_list, protocol_list=protocol_list)
|
||||
|
||||
|
||||
self.default_observation = {
|
||||
@@ -646,7 +673,7 @@ class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, parent_where: WhereType = [] ) -> FirewallObservation:
|
||||
where = parent_where+["nodes", config.hostname]
|
||||
return cls(where=where, ip_list=config.ip_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules)
|
||||
return cls(where=where, ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, protocol_list=config.protocol_list, num_rules=config.num_rules)
|
||||
|
||||
class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
@@ -663,7 +690,9 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
include_nmne: bool
|
||||
include_num_access: bool
|
||||
|
||||
num_ports: int
|
||||
ip_list: List[str]
|
||||
wildcard_list: List[str]
|
||||
port_list: List[int]
|
||||
protocol_list: List[str]
|
||||
num_rules: int
|
||||
@@ -710,13 +739,13 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
if host_config.num_services is None:
|
||||
host_config.num_services = config.num_services
|
||||
if host_config.num_applications is None:
|
||||
host_config.num_application = config.num_applications
|
||||
host_config.num_applications = config.num_applications
|
||||
if host_config.num_folders is None:
|
||||
host_config.num_folder = config.num_folders
|
||||
host_config.num_folders = config.num_folders
|
||||
if host_config.num_files is None:
|
||||
host_config.num_file = config.num_files
|
||||
host_config.num_files = config.num_files
|
||||
if host_config.num_nics is None:
|
||||
host_config.num_nic = config.num_nics
|
||||
host_config.num_nics = config.num_nics
|
||||
if host_config.include_nmne is None:
|
||||
host_config.include_nmne = config.include_nmne
|
||||
if host_config.include_num_access is None:
|
||||
@@ -727,26 +756,24 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
router_config.num_ports = config.num_ports
|
||||
if router_config.ip_list is None:
|
||||
router_config.ip_list = config.ip_list
|
||||
|
||||
if router_config.wildcard_list is None:
|
||||
router_config.wildcard_list = config.wildcard_list
|
||||
if router_config.port_list is None:
|
||||
router_config.port_list = config.port_list
|
||||
|
||||
if router_config.protocol_list is None:
|
||||
router_config.protocol_list = config.protocol_list
|
||||
|
||||
if router_config.num_rules is None:
|
||||
router_config.num_rules = config.num_rules
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
firewall_config.ip_list = config.ip_list
|
||||
|
||||
if firewall_config.wildcard_list is None:
|
||||
firewall_config.wildcard_list = config.wildcard_list
|
||||
if firewall_config.port_list is None:
|
||||
firewall_config.port_list = config.port_list
|
||||
|
||||
if firewall_config.protocol_list is None:
|
||||
firewall_config.protocol_list = config.protocol_list
|
||||
|
||||
if firewall_config.num_rules is None:
|
||||
firewall_config.num_rules = config.num_rules
|
||||
|
||||
@@ -754,4 +781,4 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
|
||||
|
||||
cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|
||||
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|
||||
|
||||
@@ -147,8 +147,10 @@ class ACLRule(SimComponent):
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
state["match_count"] = self.match_count
|
||||
return state
|
||||
|
||||
Reference in New Issue
Block a user