#2417 Finalise parsing of observation space

This commit is contained in:
Marek Wolan
2024-03-31 23:20:48 +01:00
parent 62ebca8c08
commit 8da53db822
6 changed files with 167 additions and 126 deletions

View File

@@ -41,8 +41,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -91,8 +90,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -141,10 +139,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -177,102 +172,73 @@ agents:
type: ProxyAgent
observation_space:
- type: NODES
label: NODES # What is the dictionary key called
options:
hosts:
- hostname: domain_controller
- hostname: web_server
- hostname: database_server
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
routers:
- hostname: router_1
firewalls: {}
num_host_services: 1
num_host_applications: 0
num_host_folders: 1
num_host_files: 1
num_host_network_interfaces: 2
num_router_ports: 4
num_acl_rules: 10
num_firewall_ports: 4
firewalls_internal_inbound_acl: true
firewalls_internal_outbound_acl: true
firewalls_dmz_inbound_acl: true
firewalls_dmz_outbound_acl: true
firewalls_external_inbound_acl: true
firewalls_external_outbound_acl: true
- type: LINKS
label: "LINKS"
options:
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -0,0 +1,12 @@
# flake8: noqa
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
from primaite.game.agent.observations.firewall_observation import FirewallObservation
from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation
from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation
from primaite.game.agent.observations.node_observations import NodesObservation
from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.router_observation import RouterObservation
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation

View File

@@ -94,6 +94,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"""
self.where: WhereType = where
self.include_num_access = include_num_access
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
while len(self.services) < num_services:
@@ -135,9 +137,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
"num_file_creations": 0,
"num_file_deletions": 0,
}
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
def observe(self, state: Dict) -> ObsType:
"""
@@ -160,8 +163,9 @@ class HostObservation(AbstractObservation, identifier="HOST"):
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
return obs
@property
@@ -180,9 +184,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
"num_file_creations": spaces.Discrete(4),
"num_file_deletions": spaces.Discrete(4),
}
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)
return spaces.Dict(shape)
@classmethod

View File

@@ -132,7 +132,7 @@ class LinksObservation(AbstractObservation, identifier="LINKS"):
:return: Gymnasium space representing the observation space for multiple links.
:rtype: spaces.Space
"""
return {i + 1: l.space for i, l in enumerate(self.links)}
return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)})
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:

View File

@@ -1,9 +1,10 @@
from __future__ import annotations
from typing import Dict, List, TYPE_CHECKING
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import model_validator
from primaite import getLogger
from primaite.game.agent.observations.firewall_observation import FirewallObservation
@@ -28,33 +29,63 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
"""List of configurations for router observations."""
firewalls: List[FirewallObservation.ConfigSchema] = []
"""List of configurations for firewall observations."""
num_services: int
num_services: Optional[int] = None
"""Number of services."""
num_applications: int
num_applications: Optional[int] = None
"""Number of applications."""
num_folders: int
num_folders: Optional[int] = None
"""Number of folders."""
num_files: int
num_files: Optional[int] = None
"""Number of files."""
num_nics: int
num_nics: Optional[int] = None
"""Number of network interface cards (NICs)."""
include_nmne: bool
include_nmne: Optional[bool] = None
"""Flag to include nmne."""
include_num_access: bool
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
num_ports: int
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: List[str]
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: List[str]
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: List[int]
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: List[str]
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: int
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
@model_validator(mode="after")
def force_optional_fields(self) -> NodesObservation.ConfigSchema:
"""Check that options are specified only if they are needed for the nodes that are part of the config."""
# check for hosts:
host_fields = (
self.num_services,
self.num_applications,
self.num_folders,
self.num_files,
self.num_nics,
self.include_nmne,
self.include_num_access,
)
router_fields = (
self.num_ports,
self.ip_list,
self.wildcard_list,
self.port_list,
self.protocol_list,
self.num_rules,
)
firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules)
if len(self.hosts) > 0 and any([x is None for x in host_fields]):
raise ValueError("Configuration error: Host observation options were not fully specified.")
if len(self.routers) > 0 and any([x is None for x in router_fields]):
raise ValueError("Configuration error: Router observation options were not fully specified.")
if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]):
raise ValueError("Configuration error: Firewall observation options were not fully specified.")
return self
def __init__(
self,
where: WhereType,

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Dict, List, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
@@ -120,6 +120,30 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
return cls(components=instances)
class NullObservation(AbstractObservation, identifier="NONE"):
"""Empty observation that acts as a placeholder."""
def __init__(self) -> None:
"""Initialise the empty observation."""
self.default_observation = 0
def observe(self, state: Dict) -> Any:
"""Simply return 0."""
return 0
@property
def space(self) -> spaces.Space:
"""Essentially empty space."""
return spaces.Discrete(1)
@classmethod
def from_config(
cls, config: NullObservation.ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> NullObservation:
"""Instantiate a NullObservation. Accepts parameters to comply with API."""
return cls()
class ObservationManager:
"""
Manage the observations of an Agent.
@@ -156,7 +180,7 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
def from_config(cls, config: Optional[Dict], game: "PrimaiteGame") -> "ObservationManager":
"""
Create observation space from a config.
@@ -168,6 +192,9 @@ class ObservationManager:
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config is None:
return cls(NullObservation())
print(config)
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)