# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK """Observations for access control lists (ACLs) on routers and firewalls.""" from __future__ import annotations from typing import Dict, List, Optional from gymnasium import spaces from gymnasium.core import ObsType from primaite import getLogger from primaite.game.agent.observations.observations import AbstractObservation, WhereType from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE from primaite.utils.validation.ip_protocol import IPProtocol from primaite.utils.validation.ipv4_address import StrIP from primaite.utils.validation.port import Port _LOGGER = getLogger(__name__) class ACLObservation(AbstractObservation, discriminator="acl"): """ACL observation, provides information about access control lists within the simulation environment.""" class ConfigSchema(AbstractObservation.ConfigSchema): """Configuration schema for ACLObservation.""" ip_list: Optional[List[StrIP]] = None """List of IP addresses.""" wildcard_list: Optional[List[str]] = None """List of wildcard strings.""" port_list: Optional[List[Port]] = None """List of port names.""" protocol_list: Optional[List[IPProtocol]] = None """List of protocol names.""" num_rules: Optional[int] = None """Number of ACL rules.""" def __init__( self, where: WhereType, num_rules: int, ip_list: List[StrIP], wildcard_list: List[str], port_list: List[Port], protocol_list: List[IPProtocol], ) -> None: """ Initialise an ACL observation instance. :param where: Where in the simulation state dictionary to find the relevant information for this ACL. :type where: WhereType :param num_rules: Number of ACL rules. :type num_rules: int :param ip_list: List of IP addresses. :type ip_list: List[StrIP] :param wildcard_list: List of wildcard strings. :type wildcard_list: List[str] :param port_list: List of port names. :type port_list: List[Port] :param protocol_list: List of protocol names. :type protocol_list: List[IPProtocol] """ self.where = where self.num_rules: int = num_rules self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)} self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)} self.port_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(port_list)} self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)} self.default_observation: Dict = { i: { "position": i, "permission": 0, "source_ip_id": 0, "source_wildcard_id": 0, "source_port_id": 0, "dest_ip_id": 0, "dest_wildcard_id": 0, "dest_port_id": 0, "protocol_id": 0, } for i in range(self.num_rules) } def observe(self, state: Dict) -> ObsType: """ Generate observation based on the current state of the simulation. :param state: Simulation state dictionary. :type state: Dict :return: Observation containing ACL rules. :rtype: ObsType """ acl_state: Dict = access_from_nested_dict(state, self.where) if acl_state is NOT_PRESENT_IN_STATE: return self.default_observation obs = {} acl_items = dict(acl_state.items()) for i in range(self.num_rules): rule_state = acl_items[i] if rule_state is None: obs[i] = { "position": i, "permission": 0, "source_ip_id": 0, "source_wildcard_id": 0, "source_port_id": 0, "dest_ip_id": 0, "dest_wildcard_id": 0, "dest_port_id": 0, "protocol_id": 0, } else: src_ip = rule_state["src_ip_address"] src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip] dst_ip = rule_state["dst_ip_address"] dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip] src_wildcard = rule_state["src_wildcard_mask"] src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1) dst_wildcard = rule_state["dst_wildcard_mask"] dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1) src_port = rule_state["src_port"] src_port_id = self.port_to_id.get(src_port, 1) dst_port = rule_state["dst_port"] dst_port_id = self.port_to_id.get(dst_port, 1) protocol = rule_state["protocol"] protocol_id = self.protocol_to_id.get(protocol, 1) obs[i] = { "position": i, "permission": rule_state["action"], "source_ip_id": src_node_id, "source_wildcard_id": src_wildcard_id, "source_port_id": src_port_id, "dest_ip_id": dst_node_id, "dest_wildcard_id": dst_wildcard_id, "dest_port_id": dst_port_id, "protocol_id": protocol_id, } return obs @property def space(self) -> spaces.Space: """ Gymnasium space object describing the observation space shape. :return: Gymnasium space representing the observation space for ACL rules. :rtype: spaces.Space """ return spaces.Dict( { i: spaces.Dict( { "position": spaces.Discrete(self.num_rules), "permission": spaces.Discrete(3), # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), "source_port_id": spaces.Discrete(len(self.port_to_id) + 2), "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2), "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2), "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2), "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2), } ) for i in range(self.num_rules) } ) @classmethod def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> ACLObservation: """ Create an ACL observation from a configuration schema. :param config: Configuration schema containing the necessary information for the ACL observation. :type config: ConfigSchema :param parent_where: Where in the simulation state dictionary to find the information about this ACL's parent node. A typical location for a node might be ['network', 'nodes', ]. :type parent_where: WhereType, optional :return: Constructed ACL observation instance. :rtype: ACLObservation """ return cls( where=parent_where + ["acl", "acl"], num_rules=config.num_rules, ip_list=config.ip_list, wildcard_list=config.wildcard_list, port_list=config.port_list, protocol_list=config.protocol_list, )