2025-01-02 15:05:06 +00:00
|
|
|
# © Crown-owned copyright 2025, Defence Science and Technology Laboratory UK
|
2025-03-13 15:07:32 +00:00
|
|
|
"""Observations for access control lists (ACLs) on routers and firewalls."""
|
2024-03-29 14:14:03 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
|
2024-04-03 22:16:54 +01:00
|
|
|
from typing import Dict, List, Optional
|
2024-03-29 14:14:03 +00:00
|
|
|
|
|
|
|
|
from gymnasium import spaces
|
|
|
|
|
from gymnasium.core import ObsType
|
|
|
|
|
|
|
|
|
|
from primaite import getLogger
|
|
|
|
|
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
|
|
|
|
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
2025-01-31 12:18:52 +00:00
|
|
|
from primaite.utils.validation.ip_protocol import IPProtocol
|
|
|
|
|
from primaite.utils.validation.ipv4_address import StrIP
|
|
|
|
|
from primaite.utils.validation.port import Port
|
2024-03-29 14:14:03 +00:00
|
|
|
|
|
|
|
|
_LOGGER = getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
2025-02-03 16:24:03 +00:00
|
|
|
class ACLObservation(AbstractObservation, discriminator="acl"):
    """ACL observation, provides information about access control lists within the simulation environment.

    Every observed rule slot is encoded as a dictionary of integer ids. In each ``*_id`` field the value ``0`` is
    reserved for an unused slot and ``1`` is reserved for "any"; the configured IP addresses, wildcards, ports and
    protocols are numbered from ``2`` upwards in the order in which they appear in their respective lists.
    """

    class ConfigSchema(AbstractObservation.ConfigSchema):
        """Configuration schema for ACLObservation."""

        ip_list: Optional[List[StrIP]] = None
        """List of IP addresses."""
        wildcard_list: Optional[List[str]] = None
        """List of wildcard strings."""
        port_list: Optional[List[Port]] = None
        """List of port names."""
        protocol_list: Optional[List[IPProtocol]] = None
        """List of protocol names."""
        num_rules: Optional[int] = None
        """Number of ACL rules."""

    def __init__(
        self,
        where: WhereType,
        num_rules: int,
        ip_list: List[StrIP],
        wildcard_list: List[str],
        port_list: List[Port],
        protocol_list: List[IPProtocol],
    ) -> None:
        """
        Initialise an ACL observation instance.

        Any of the list arguments may be ``None`` (the default value of the matching ``ConfigSchema`` field); a
        ``None`` list is treated as an empty list, so that field can only ever be observed as ``0`` or ``1``.

        :param where: Where in the simulation state dictionary to find the relevant information for this ACL.
        :type where: WhereType
        :param num_rules: Number of ACL rules.
        :type num_rules: int
        :param ip_list: List of IP addresses.
        :type ip_list: List[StrIP]
        :param wildcard_list: List of wildcard strings.
        :type wildcard_list: List[str]
        :param port_list: List of port names.
        :type port_list: List[Port]
        :param protocol_list: List of protocol names.
        :type protocol_list: List[IPProtocol]
        """
        self.where = where
        self.num_rules: int = num_rules
        self.ip_to_id: Dict[str, int] = self._index_from_two(ip_list)
        self.wildcard_to_id: Dict[str, int] = self._index_from_two(wildcard_list)
        self.port_to_id: Dict[str, int] = self._index_from_two(port_list)
        self.protocol_to_id: Dict[str, int] = self._index_from_two(protocol_list)
        # Returned as-is by observe() when the ACL cannot be found in the simulation state.
        self.default_observation: Dict = {i: self._empty_rule(i) for i in range(self.num_rules)}

    @staticmethod
    def _index_from_two(values: Optional[List]) -> Dict:
        """Map each value to its list index plus two, keeping ids 0 and 1 reserved; ``None`` gives an empty map."""
        return {value: idx for idx, value in enumerate(values or [], start=2)}

    @staticmethod
    def _empty_rule(position: int) -> Dict[str, int]:
        """Build the observation of an unused rule slot: every field is 0 except ``position``."""
        return {
            "position": position,
            "permission": 0,
            "source_ip_id": 0,
            "source_wildcard_id": 0,
            "source_port_id": 0,
            "dest_ip_id": 0,
            "dest_wildcard_id": 0,
            "dest_port_id": 0,
            "protocol_id": 0,
        }

    def observe(self, state: Dict) -> ObsType:
        """
        Generate observation based on the current state of the simulation.

        If the ACL is not present in the state, the pre-built default observation (all slots unused) is returned.
        Otherwise each of the first ``num_rules`` slots is read from the ACL state: a slot holding ``None`` is
        reported as unused, and a populated slot has each of its fields translated to an integer id. A field
        value that is ``None`` or is not one of the configured values is reported as ``1``.

        :param state: Simulation state dictionary.
        :type state: Dict
        :return: Observation containing ACL rules.
        :rtype: ObsType
        :raises KeyError: If the ACL state has no entry for one of the first ``num_rules`` positions.
        """
        acl_state: Dict = access_from_nested_dict(state, self.where)
        if acl_state is NOT_PRESENT_IN_STATE:
            return self.default_observation

        obs = {}
        for i in range(self.num_rules):
            rule_state = acl_state[i]
            if rule_state is None:
                obs[i] = self._empty_rule(i)
                continue

            obs[i] = {
                "position": i,
                "permission": rule_state["action"],
                # IP lookups use the same fallback as the other fields, so a rule that mentions an address
                # outside ip_list is reported as 1 instead of aborting the observation with a KeyError.
                "source_ip_id": self.ip_to_id.get(rule_state["src_ip_address"], 1),
                "source_wildcard_id": self.wildcard_to_id.get(rule_state["src_wildcard_mask"], 1),
                "source_port_id": self.port_to_id.get(rule_state["src_port"], 1),
                "dest_ip_id": self.ip_to_id.get(rule_state["dst_ip_address"], 1),
                "dest_wildcard_id": self.wildcard_to_id.get(rule_state["dst_wildcard_mask"], 1),
                "dest_port_id": self.port_to_id.get(rule_state["dst_port"], 1),
                "protocol_id": self.protocol_to_id.get(rule_state["protocol"], 1),
            }
        return obs

    @property
    def space(self) -> spaces.Space:
        """
        Gymnasium space object describing the observation space shape.

        The space is a dictionary keyed by rule position, where each entry is itself a dictionary of discrete
        spaces with the same keys as a single rule produced by :meth:`observe`.

        :return: Gymnasium space representing the observation space for ACL rules.
        :rtype: spaces.Space
        """
        return spaces.Dict(
            {
                i: spaces.Dict(
                    {
                        "position": spaces.Discrete(self.num_rules),
                        "permission": spaces.Discrete(3),
                        # adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
                        "source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
                        "source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
                        "source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
                        "dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
                        "dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
                        "dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
                        "protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
                    }
                )
                for i in range(self.num_rules)
            }
        )

    @classmethod
    def from_config(cls, config: ConfigSchema, parent_where: WhereType = None) -> ACLObservation:
        """
        Create an ACL observation from a configuration schema.

        The ACL is looked up at ``parent_where`` extended with ``["acl", "acl"]``.

        :param config: Configuration schema containing the necessary information for the ACL observation.
        :type config: ConfigSchema
        :param parent_where: Where in the simulation state dictionary to find the information about this ACL's
            parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>]. When
            omitted or ``None``, an empty path is used.
        :type parent_where: WhereType, optional
        :return: Constructed ACL observation instance.
        :rtype: ACLObservation
        """
        # A None default replaces the shared mutable ``[]`` default; copying also leaves the caller's path untouched.
        base_where = [] if parent_where is None else list(parent_where)
        return cls(
            where=base_where + ["acl", "acl"],
            num_rules=config.num_rules,
            ip_list=config.ip_list,
            wildcard_list=config.wildcard_list,
            port_list=config.port_list,
            protocol_list=config.protocol_list,
        )
|