diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml index 06028ee1..d810e58a 100644 --- a/src/primaite/config/_package_data/data_manipulation.yaml +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -41,8 +41,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -91,8 +90,7 @@ agents: 0: 0.3 1: 0.6 2: 0.1 - observation_space: - type: UC2GreenObservation + observation_space: null action_space: action_list: - type: DONOTHING @@ -141,10 +139,7 @@ agents: team: RED type: RedDatabaseCorruptingAgent - observation_space: - type: UC2RedObservation - options: - nodes: {} + observation_space: null action_space: action_list: @@ -177,102 +172,73 @@ agents: type: ProxyAgent observation_space: - - type: NODES - label: NODES # What is the dictionary key called - options: - hosts: - - hostname: domain_controller - - hostname: web_server - - hostname: database_server - - hostname: backup_server - - hostname: security_suite - - hostname: client_1 - - hostname: client_2 - routers: - - hostname: router_1 - firewalls: {} - - num_host_services: 1 - num_host_applications: 0 - num_host_folders: 1 - num_host_files: 1 - num_host_network_interfaces: 2 - num_router_ports: 4 - num_acl_rules: 10 - num_firewall_ports: 4 - firewalls_internal_inbound_acl: true - firewalls_internal_outbound_acl: true - firewalls_dmz_inbound_acl: true - firewalls_dmz_outbound_acl: true - firewalls_external_inbound_acl: true - firewalls_external_outbound_acl: true - - type: LINKS - label: "LINKS" - options: - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - - observation_space: - type: UC2BlueObservation + type: CUSTOM options: - nodes: - - node_hostname: domain_controller - services: - - service_name: DNSServer - - node_hostname: web_server - services: - - service_name: WebServer - - node_hostname: database_server - folders: - - folder_name: database - files: - - file_name: database.db - - node_hostname: backup_server - - node_hostname: security_suite - - node_hostname: client_1 - - node_hostname: client_2 - links: - - link_ref: router_1___switch_1 - - link_ref: router_1___switch_2 - - link_ref: switch_1___domain_controller - - link_ref: switch_1___web_server - - link_ref: switch_1___database_server - - link_ref: switch_1___backup_server - - link_ref: switch_1___security_suite - - link_ref: switch_2___client_1 - - link_ref: switch_2___client_2 - - link_ref: switch_2___security_suite - acl: - options: - max_acl_rules: 10 - router_hostname: router_1 - ip_address_order: - - node_hostname: domain_controller - nic_num: 1 - - node_hostname: web_server - nic_num: 1 - - node_hostname: database_server - nic_num: 1 - - node_hostname: backup_server - nic_num: 1 - - node_hostname: security_suite - nic_num: 1 - - node_hostname: client_1 - nic_num: 1 - - node_hostname: client_2 - nic_num: 1 - - node_hostname: security_suite - nic_num: 2 - ics: null + components: + - type: NODES + label: NODES + options: + hosts: + - hostname: domain_controller + - hostname: web_server + services: + - service_name: WebServer + - hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - hostname: backup_server + - hostname: security_suite + - hostname: client_1 + - hostname: client_2 + num_services: 1 + num_applications: 0 + num_folders: 1 + num_files: 1 + num_nics: 2 + include_num_access: false + include_nmne: true + routers: + - hostname: router_1 + num_ports: 0 + ip_list: + - 192.168.1.10 + - 192.168.1.12 + - 192.168.1.14 + - 192.168.1.16 + - 192.168.1.110 + - 192.168.10.21 + - 192.168.10.22 + - 192.168.10.110 + wildcard_list: + - 0.0.0.1 + port_list: + - 80 + - 5432 + protocol_list: + - ICMP + - TCP + - UDP + num_rules: 10 + + - type: LINKS + label: LINKS + options: + link_references: + - router_1___switch_1 + - router_1___switch_2 + - switch_1___domain_controller + - switch_1___web_server + - switch_1___database_server + - switch_1___backup_server + - switch_1___security_suite + - switch_2___client_1 + - switch_2___client_2 + - switch_2___security_suite + - type: "NONE" + label: ICS + options: {} action_space: action_list: diff --git a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py index e69de29b..b9d97ae6 100644 --- a/src/primaite/game/agent/observations/__init__.py +++ b/src/primaite/game/agent/observations/__init__.py @@ -0,0 +1,12 @@ +# flake8: noqa +from primaite.game.agent.observations.acl_observation import ACLObservation +from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.game.agent.observations.firewall_observation import FirewallObservation +from primaite.game.agent.observations.host_observations import HostObservation +from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation +from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation +from primaite.game.agent.observations.node_observations import NodesObservation +from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.router_observation import RouterObservation +from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation diff --git a/src/primaite/game/agent/observations/host_observations.py b/src/primaite/game/agent/observations/host_observations.py index d71583b3..3ee5f2c7 100644 --- a/src/primaite/game/agent/observations/host_observations.py +++ b/src/primaite/game/agent/observations/host_observations.py @@ -94,6 +94,8 @@ class HostObservation(AbstractObservation, identifier="HOST"): """ self.where: WhereType = where + self.include_num_access = include_num_access + # Ensure lists have lengths equal to specified counts by truncating or padding self.services: List[ServiceObservation] = services while len(self.services) < num_services: @@ -135,9 +137,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, "operating_status": 0, - "num_file_creations": 0, - "num_file_deletions": 0, } + if self.include_num_access: + self.default_observation["num_file_creations"] = 0 + self.default_observation["num_file_deletions"] = 0 def observe(self, state: Dict) -> ObsType: """ @@ -160,8 +163,9 @@ class HostObservation(AbstractObservation, identifier="HOST"): obs["NICS"] = { i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) } - obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] - obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] + if self.include_num_access: + obs["num_file_creations"] = node_state["file_system"]["num_file_creations"] + obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"] return obs @property @@ -180,9 +184,10 @@ class HostObservation(AbstractObservation, identifier="HOST"): "NICS": spaces.Dict( {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} ), - "num_file_creations": spaces.Discrete(4), - "num_file_deletions": spaces.Discrete(4), } + if self.include_num_access: + shape["num_file_creations"] = spaces.Discrete(4) + shape["num_file_deletions"] = spaces.Discrete(4) return spaces.Dict(shape) @classmethod diff --git a/src/primaite/game/agent/observations/link_observation.py b/src/primaite/game/agent/observations/link_observation.py index f810bb36..be08657d 100644 --- a/src/primaite/game/agent/observations/link_observation.py +++ b/src/primaite/game/agent/observations/link_observation.py @@ -132,7 +132,7 @@ class LinksObservation(AbstractObservation, identifier="LINKS"): :return: Gymnasium space representing the observation space for multiple links. :rtype: spaces.Space """ - return {i + 1: l.space for i, l in enumerate(self.links)} + return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)}) @classmethod def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation: diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py index 7d227bb7..dce33a04 100644 --- a/src/primaite/game/agent/observations/node_observations.py +++ b/src/primaite/game/agent/observations/node_observations.py @@ -1,9 +1,10 @@ from __future__ import annotations -from typing import Dict, List, TYPE_CHECKING +from typing import Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType +from pydantic import model_validator from primaite import getLogger from primaite.game.agent.observations.firewall_observation import FirewallObservation @@ -28,33 +29,63 @@ class NodesObservation(AbstractObservation, identifier="NODES"): """List of configurations for router observations.""" firewalls: List[FirewallObservation.ConfigSchema] = [] """List of configurations for firewall observations.""" - num_services: int + num_services: Optional[int] = None """Number of services.""" - num_applications: int + num_applications: Optional[int] = None """Number of applications.""" - num_folders: int + num_folders: Optional[int] = None """Number of folders.""" - num_files: int + num_files: Optional[int] = None """Number of files.""" - num_nics: int + num_nics: Optional[int] = None """Number of network interface cards (NICs).""" - include_nmne: bool + include_nmne: Optional[bool] = None """Flag to include nmne.""" - include_num_access: bool + include_num_access: Optional[bool] = None """Flag to include the number of accesses.""" - num_ports: int + num_ports: Optional[int] = None """Number of ports.""" - ip_list: List[str] + ip_list: Optional[List[str]] = None """List of IP addresses for encoding ACLs.""" - wildcard_list: List[str] + wildcard_list: Optional[List[str]] = None """List of IP wildcards for encoding ACLs.""" - port_list: List[int] + port_list: Optional[List[int]] = None """List of ports for encoding ACLs.""" - protocol_list: List[str] + protocol_list: Optional[List[str]] = None """List of protocols for encoding ACLs.""" - num_rules: int + num_rules: Optional[int] = None """Number of rules ACL rules to show.""" + @model_validator(mode="after") + def force_optional_fields(self) -> NodesObservation.ConfigSchema: + """Check that options are specified only if they are needed for the nodes that are part of the config.""" + # check for hosts: + host_fields = ( + self.num_services, + self.num_applications, + self.num_folders, + self.num_files, + self.num_nics, + self.include_nmne, + self.include_num_access, + ) + router_fields = ( + self.num_ports, + self.ip_list, + self.wildcard_list, + self.port_list, + self.protocol_list, + self.num_rules, + ) + firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules) + if len(self.hosts) > 0 and any([x is None for x in host_fields]): + raise ValueError("Configuration error: Host observation options were not fully specified.") + if len(self.routers) > 0 and any([x is None for x in router_fields]): + raise ValueError("Configuration error: Router observation options were not fully specified.") + if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]): + raise ValueError("Configuration error: Firewall observation options were not fully specified.") + return self + def __init__( self, where: WhereType, diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py index 84311984..3703fa1c 100644 --- a/src/primaite/game/agent/observations/observation_manager.py +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Dict, List, TYPE_CHECKING +from typing import Any, Dict, List, Optional, TYPE_CHECKING from gymnasium import spaces from gymnasium.core import ObsType @@ -120,6 +120,30 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"): return cls(components=instances) +class NullObservation(AbstractObservation, identifier="NONE"): + """Empty observation that acts as a placeholder.""" + + def __init__(self) -> None: + """Initialise the empty observation.""" + self.default_observation = 0 + + def observe(self, state: Dict) -> Any: + """Simply return 0.""" + return 0 + + @property + def space(self) -> spaces.Space: + """Essentially empty space.""" + return spaces.Discrete(1) + + @classmethod + def from_config( + cls, config: NullObservation.ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = [] + ) -> NullObservation: + """Instantiate a NullObservation. Accepts parameters to comply with API.""" + return cls() + + class ObservationManager: """ Manage the observations of an Agent. @@ -156,7 +180,7 @@ class ObservationManager: return self.obs.space @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": + def from_config(cls, config: Optional[Dict], game: "PrimaiteGame") -> "ObservationManager": """ Create observation space from a config. @@ -168,6 +192,9 @@ class ObservationManager: :param game: Reference to the PrimaiteGame object that spawned this observation. :type game: PrimaiteGame """ + if config is None: + return cls(NullObservation()) + print(config) obs_type = config["type"] obs_class = AbstractObservation._registry[obs_type] observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)