244 lines
11 KiB
Python
244 lines
11 KiB
Python
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
|
|
from __future__ import annotations
|
|
|
|
from typing import Dict, List, Optional
|
|
|
|
from gymnasium import spaces
|
|
from gymnasium.core import ObsType
|
|
from pydantic import model_validator
|
|
|
|
from primaite import getLogger
|
|
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
|
from primaite.game.agent.observations.host_observations import HostObservation
|
|
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
|
from primaite.game.agent.observations.router_observation import RouterObservation
|
|
|
|
_LOGGER = getLogger(__name__)
|
|
|
|
|
|
class NodesObservation(AbstractObservation, identifier="NODES"):
|
|
"""Nodes observation, provides status information about nodes within the simulation environment."""
|
|
|
|
class ConfigSchema(AbstractObservation.ConfigSchema):
|
|
"""Configuration schema for NodesObservation."""
|
|
|
|
hosts: List[HostObservation.ConfigSchema] = []
|
|
"""List of configurations for host observations."""
|
|
routers: List[RouterObservation.ConfigSchema] = []
|
|
"""List of configurations for router observations."""
|
|
firewalls: List[FirewallObservation.ConfigSchema] = []
|
|
"""List of configurations for firewall observations."""
|
|
num_services: Optional[int] = None
|
|
"""Number of services."""
|
|
num_applications: Optional[int] = None
|
|
"""Number of applications."""
|
|
num_folders: Optional[int] = None
|
|
"""Number of folders."""
|
|
num_files: Optional[int] = None
|
|
"""Number of files."""
|
|
num_nics: Optional[int] = None
|
|
"""Number of network interface cards (NICs)."""
|
|
include_nmne: Optional[bool] = None
|
|
"""Flag to include nmne."""
|
|
monitored_traffic: Optional[Dict] = None
|
|
"""A dict containing which traffic types are to be included in the observation."""
|
|
include_num_access: Optional[bool] = None
|
|
"""Flag to include the number of accesses."""
|
|
file_system_requires_scan: bool = True
|
|
"""If True, the folder must be scanned to update the health state. If False, the true state is always shown."""
|
|
services_requires_scan: bool = True
|
|
"""If True, the services must be scanned to update the health state.
|
|
If False, the true state is always shown."""
|
|
applications_requires_scan: bool = True
|
|
"""If True, the applications must be scanned to update the health state.
|
|
If False, the true state is always shown."""
|
|
include_users: Optional[bool] = True
|
|
"""If True, report user session information."""
|
|
num_ports: Optional[int] = None
|
|
"""Number of ports."""
|
|
ip_list: Optional[List[str]] = None
|
|
"""List of IP addresses for encoding ACLs."""
|
|
wildcard_list: Optional[List[str]] = None
|
|
"""List of IP wildcards for encoding ACLs."""
|
|
port_list: Optional[List[int]] = None
|
|
"""List of ports for encoding ACLs."""
|
|
protocol_list: Optional[List[str]] = None
|
|
"""List of protocols for encoding ACLs."""
|
|
num_rules: Optional[int] = None
|
|
"""Number of rules ACL rules to show."""
|
|
|
|
@model_validator(mode="after")
|
|
def force_optional_fields(self) -> NodesObservation.ConfigSchema:
|
|
"""Check that options are specified only if they are needed for the nodes that are part of the config."""
|
|
# check for hosts:
|
|
host_fields = (
|
|
self.num_services,
|
|
self.num_applications,
|
|
self.num_folders,
|
|
self.num_files,
|
|
self.num_nics,
|
|
self.include_nmne,
|
|
self.include_num_access,
|
|
)
|
|
router_fields = (
|
|
self.num_ports,
|
|
self.ip_list,
|
|
self.wildcard_list,
|
|
self.port_list,
|
|
self.protocol_list,
|
|
self.num_rules,
|
|
)
|
|
firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules)
|
|
if len(self.hosts) > 0 and any([x is None for x in host_fields]):
|
|
raise ValueError("Configuration error: Host observation options were not fully specified.")
|
|
if len(self.routers) > 0 and any([x is None for x in router_fields]):
|
|
raise ValueError("Configuration error: Router observation options were not fully specified.")
|
|
if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]):
|
|
raise ValueError("Configuration error: Firewall observation options were not fully specified.")
|
|
return self
|
|
|
|
def __init__(
|
|
self,
|
|
where: WhereType,
|
|
hosts: List[HostObservation],
|
|
routers: List[RouterObservation],
|
|
firewalls: List[FirewallObservation],
|
|
) -> None:
|
|
"""
|
|
Initialise a nodes observation instance.
|
|
|
|
:param where: Where in the simulation state dictionary to find the relevant information for nodes.
|
|
A typical location for nodes might be ['network', 'nodes'].
|
|
:type where: WhereType
|
|
:param hosts: List of host observations.
|
|
:type hosts: List[HostObservation]
|
|
:param routers: List of router observations.
|
|
:type routers: List[RouterObservation]
|
|
:param firewalls: List of firewall observations.
|
|
:type firewalls: List[FirewallObservation]
|
|
"""
|
|
self.where: WhereType = where
|
|
|
|
self.hosts: List[HostObservation] = hosts
|
|
self.routers: List[RouterObservation] = routers
|
|
self.firewalls: List[FirewallObservation] = firewalls
|
|
|
|
self.default_observation = {
|
|
**{f"HOST{i}": host.default_observation for i, host in enumerate(self.hosts)},
|
|
**{f"ROUTER{i}": router.default_observation for i, router in enumerate(self.routers)},
|
|
**{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)},
|
|
}
|
|
|
|
def observe(self, state: Dict) -> ObsType:
|
|
"""
|
|
Generate observation based on the current state of the simulation.
|
|
|
|
:param state: Simulation state dictionary.
|
|
:type state: Dict
|
|
:return: Observation containing status information about nodes.
|
|
:rtype: ObsType
|
|
"""
|
|
obs = {
|
|
**{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)},
|
|
**{f"ROUTER{i}": router.observe(state) for i, router in enumerate(self.routers)},
|
|
**{f"FIREWALL{i}": firewall.observe(state) for i, firewall in enumerate(self.firewalls)},
|
|
}
|
|
return obs
|
|
|
|
@property
|
|
def space(self) -> spaces.Space:
|
|
"""
|
|
Gymnasium space object describing the observation space shape.
|
|
|
|
:return: Gymnasium space representing the observation space for nodes.
|
|
:rtype: spaces.Space
|
|
"""
|
|
space = spaces.Dict(
|
|
{
|
|
**{f"HOST{i}": host.space for i, host in enumerate(self.hosts)},
|
|
**{f"ROUTER{i}": router.space for i, router in enumerate(self.routers)},
|
|
**{f"FIREWALL{i}": firewall.space for i, firewall in enumerate(self.firewalls)},
|
|
}
|
|
)
|
|
return space
|
|
|
|
@classmethod
|
|
def from_config(cls, config: ConfigSchema, parent_where: WhereType = []) -> NodesObservation:
|
|
"""
|
|
Create a nodes observation from a configuration schema.
|
|
|
|
:param config: Configuration schema containing the necessary information for nodes observation.
|
|
:type config: ConfigSchema
|
|
:param parent_where: Where in the simulation state dictionary to find the information about nodes.
|
|
A typical location for nodes might be ['network', 'nodes'].
|
|
:type parent_where: WhereType, optional
|
|
:return: Constructed nodes observation instance.
|
|
:rtype: NodesObservation
|
|
"""
|
|
if not parent_where:
|
|
where = ["network", "nodes"]
|
|
else:
|
|
where = parent_where + ["nodes"]
|
|
|
|
for host_config in config.hosts:
|
|
if host_config.num_services is None:
|
|
host_config.num_services = config.num_services
|
|
if host_config.num_applications is None:
|
|
host_config.num_applications = config.num_applications
|
|
if host_config.num_folders is None:
|
|
host_config.num_folders = config.num_folders
|
|
if host_config.num_files is None:
|
|
host_config.num_files = config.num_files
|
|
if host_config.num_nics is None:
|
|
host_config.num_nics = config.num_nics
|
|
if host_config.include_nmne is None:
|
|
host_config.include_nmne = config.include_nmne
|
|
if host_config.monitored_traffic is None:
|
|
host_config.monitored_traffic = config.monitored_traffic
|
|
if host_config.include_num_access is None:
|
|
host_config.include_num_access = config.include_num_access
|
|
if host_config.file_system_requires_scan is None:
|
|
host_config.file_system_requires_scan = config.file_system_requires_scan
|
|
if host_config.services_requires_scan is None:
|
|
host_config.services_requires_scan = config.services_requires_scan
|
|
if host_config.applications_requires_scan is None:
|
|
host_config.applications_requires_scan = config.applications_requires_scan
|
|
if host_config.include_users is None:
|
|
host_config.include_users = config.include_users
|
|
|
|
for router_config in config.routers:
|
|
if router_config.num_ports is None:
|
|
router_config.num_ports = config.num_ports
|
|
if router_config.ip_list is None:
|
|
router_config.ip_list = config.ip_list
|
|
if router_config.wildcard_list is None:
|
|
router_config.wildcard_list = config.wildcard_list
|
|
if router_config.port_list is None:
|
|
router_config.port_list = config.port_list
|
|
if router_config.protocol_list is None:
|
|
router_config.protocol_list = config.protocol_list
|
|
if router_config.num_rules is None:
|
|
router_config.num_rules = config.num_rules
|
|
if router_config.include_users is None:
|
|
router_config.include_users = config.include_users
|
|
|
|
for firewall_config in config.firewalls:
|
|
if firewall_config.ip_list is None:
|
|
firewall_config.ip_list = config.ip_list
|
|
if firewall_config.wildcard_list is None:
|
|
firewall_config.wildcard_list = config.wildcard_list
|
|
if firewall_config.port_list is None:
|
|
firewall_config.port_list = config.port_list
|
|
if firewall_config.protocol_list is None:
|
|
firewall_config.protocol_list = config.protocol_list
|
|
if firewall_config.num_rules is None:
|
|
firewall_config.num_rules = config.num_rules
|
|
if firewall_config.include_users is None:
|
|
firewall_config.include_users = config.include_users
|
|
|
|
hosts = [HostObservation.from_config(config=c, parent_where=where) for c in config.hosts]
|
|
routers = [RouterObservation.from_config(config=c, parent_where=where) for c in config.routers]
|
|
firewalls = [FirewallObservation.from_config(config=c, parent_where=where) for c in config.firewalls]
|
|
|
|
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|