Fix ACL observations

This commit is contained in:
Marek Wolan
2024-01-25 09:27:08 +00:00
parent 88c1d16f11
commit 0a65f32adf
2 changed files with 31 additions and 16 deletions

View File

@@ -1,5 +1,6 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
@@ -648,10 +649,13 @@ class AclObservation(AbstractObservation):
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
for i, rule_state in acl_state.items():
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i + 1] = {
"position": i,
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
@@ -660,15 +664,26 @@ class AclObservation(AbstractObservation):
"protocol": 0,
}
else:
obs[i + 1] = {
"position": i,
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
"source_port": self.port_to_id[rule_state["src_port"]],
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
"dest_port": self.port_to_id[rule_state["dst_port"]],
"protocol": self.protocol_to_id[rule_state["protocol"]],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property

View File

@@ -19,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
class ACLAction(Enum):
"""Enum for defining the ACL action types."""
DENY = 0
PERMIT = 1
DENY = 2
class ACLRule(SimComponent):
@@ -66,11 +66,11 @@ class ACLRule(SimComponent):
"""
state = super().describe_state()
state["action"] = self.action.value
state["protocol"] = self.protocol.value if self.protocol else None
state["protocol"] = self.protocol.name if self.protocol else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_port"] = self.src_port.value if self.src_port else None
state["src_port"] = self.src_port.name if self.src_port else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_port"] = self.dst_port.value if self.dst_port else None
state["dst_port"] = self.dst_port.name if self.dst_port else None
return state
@@ -733,8 +733,8 @@ class Router(Node):
:return: A dictionary representing the current state.
"""
state = super().describe_state()
state["num_ports"] = (self.num_ports,)
state["acl"] = (self.acl.describe_state(),)
state["num_ports"] = self.num_ports
state["acl"] = self.acl.describe_state()
return state
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: