From 0a65f32adfa47cd84c640f3ae9c1b0aed0bc1b94 Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 25 Jan 2024 09:27:08 +0000 Subject: [PATCH] Fix ACL observations --- src/primaite/game/agent/observations.py | 35 +++++++++++++------ .../network/hardware/nodes/router.py | 12 +++---- 2 files changed, 31 insertions(+), 16 deletions(-) diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py index cac5b91e..b7962827 100644 --- a/src/primaite/game/agent/observations.py +++ b/src/primaite/game/agent/observations.py @@ -1,5 +1,6 @@ """Manages the observation space for the agent.""" from abc import ABC, abstractmethod +from ipaddress import IPv4Address from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium import spaces @@ -648,10 +649,13 @@ class AclObservation(AbstractObservation): # TODO: what if the ACL has more rules than num of max rules for obs space obs = {} - for i, rule_state in acl_state.items(): + acl_items = dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] if rule_state is None: - obs[i + 1] = { - "position": i, + obs[i] = { + "position": i - 1, "permission": 0, "source_node_id": 0, "source_port": 0, @@ -660,15 +664,26 @@ class AclObservation(AbstractObservation): "protocol": 0, } else: - obs[i + 1] = { - "position": i, + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, "permission": rule_state["action"], - "source_node_id": self.node_to_id[rule_state["src_ip_address"]], - "source_port": self.port_to_id[rule_state["src_port"]], - "dest_node_id": self.node_to_id[rule_state["dst_ip_address"]], - "dest_port": self.port_to_id[rule_state["dst_port"]], - "protocol": self.protocol_to_id[rule_state["protocol"]], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, } + i += 1 return obs @property diff --git a/src/primaite/simulator/network/hardware/nodes/router.py b/src/primaite/simulator/network/hardware/nodes/router.py index bb923d62..0c5d0ce9 100644 --- a/src/primaite/simulator/network/hardware/nodes/router.py +++ b/src/primaite/simulator/network/hardware/nodes/router.py @@ -19,8 +19,8 @@ from primaite.simulator.system.core.sys_log import SysLog class ACLAction(Enum): """Enum for defining the ACL action types.""" - DENY = 0 PERMIT = 1 + DENY = 2 class ACLRule(SimComponent): @@ -66,11 +66,11 @@ class ACLRule(SimComponent): """ state = super().describe_state() state["action"] = self.action.value - state["protocol"] = self.protocol.value if self.protocol else None + state["protocol"] = self.protocol.name if self.protocol else None state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None - state["src_port"] = self.src_port.value if self.src_port else None + state["src_port"] = self.src_port.name if self.src_port else None state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None - state["dst_port"] = self.dst_port.value if self.dst_port else None + state["dst_port"] = self.dst_port.name if self.dst_port else None return state @@ -733,8 +733,8 @@ class Router(Node): :return: A dictionary representing the current state. """ state = super().describe_state() - state["num_ports"] = (self.num_ports,) - state["acl"] = (self.acl.describe_state(),) + state["num_ports"] = self.num_ports + state["acl"] = self.acl.describe_state() return state def route_frame(self, frame: Frame, from_nic: NIC, re_attempt: bool = False) -> None: