Fix ACL observations
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
@@ -648,10 +649,13 @@ class AclObservation(AbstractObservation):
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
for i, rule_state in acl_state.items():
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
@@ -660,15 +664,26 @@ class AclObservation(AbstractObservation):
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
obs[i + 1] = {
|
||||
"position": i,
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": self.node_to_id[rule_state["src_ip_address"]],
|
||||
"source_port": self.port_to_id[rule_state["src_port"]],
|
||||
"dest_node_id": self.node_to_id[rule_state["dst_ip_address"]],
|
||||
"dest_port": self.port_to_id[rule_state["dst_port"]],
|
||||
"protocol": self.protocol_to_id[rule_state["protocol"]],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
|
||||
@@ -19,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog
|
||||
class ACLAction(Enum):
|
||||
"""Enum for defining the ACL action types."""
|
||||
|
||||
DENY = 0
|
||||
PERMIT = 1
|
||||
DENY = 2
|
||||
|
||||
|
||||
class ACLRule(SimComponent):
|
||||
@@ -66,11 +66,11 @@ class ACLRule(SimComponent):
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.value if self.protocol else None
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_port"] = self.src_port.value if self.src_port else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_port"] = self.dst_port.value if self.dst_port else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
return state
|
||||
|
||||
|
||||
@@ -733,8 +733,8 @@ class Router(Node):
|
||||
:return: A dictionary representing the current state.
|
||||
"""
|
||||
state = super().describe_state()
|
||||
state["num_ports"] = (self.num_ports,)
|
||||
state["acl"] = (self.acl.describe_state(),)
|
||||
state["num_ports"] = self.num_ports
|
||||
state["acl"] = self.acl.describe_state()
|
||||
return state
|
||||
|
||||
def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None:
|
||||
|
||||
Reference in New Issue
Block a user