#2417 Finalise parsing of observation space
This commit is contained in:
@@ -41,8 +41,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -91,8 +90,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -141,10 +139,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -177,102 +172,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
- type: NODES
|
||||
label: NODES # What is the dictionary key called
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
- hostname: database_server
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
routers:
|
||||
- hostname: router_1
|
||||
firewalls: {}
|
||||
|
||||
num_host_services: 1
|
||||
num_host_applications: 0
|
||||
num_host_folders: 1
|
||||
num_host_files: 1
|
||||
num_host_network_interfaces: 2
|
||||
num_router_ports: 4
|
||||
num_acl_rules: 10
|
||||
num_firewall_ports: 4
|
||||
firewalls_internal_inbound_acl: true
|
||||
firewalls_internal_outbound_acl: true
|
||||
firewalls_dmz_inbound_acl: true
|
||||
firewalls_dmz_outbound_acl: true
|
||||
firewalls_external_inbound_acl: true
|
||||
firewalls_external_outbound_acl: true
|
||||
- type: LINKS
|
||||
label: "LINKS"
|
||||
options:
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
# flake8: noqa
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation
|
||||
from primaite.game.agent.observations.node_observations import NodesObservation
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
|
||||
@@ -94,6 +94,8 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.include_num_access = include_num_access
|
||||
|
||||
# Ensure lists have lengths equal to specified counts by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services:
|
||||
@@ -135,9 +137,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
"num_file_creations": 0,
|
||||
"num_file_deletions": 0,
|
||||
}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_file_creations"] = 0
|
||||
self.default_observation["num_file_deletions"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
@@ -160,8 +163,9 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
}
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
return obs
|
||||
|
||||
@property
|
||||
@@ -180,9 +184,10 @@ class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
"num_file_creations": spaces.Discrete(4),
|
||||
"num_file_deletions": spaces.Discrete(4),
|
||||
}
|
||||
if self.include_num_access:
|
||||
shape["num_file_creations"] = spaces.Discrete(4)
|
||||
shape["num_file_deletions"] = spaces.Discrete(4)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -132,7 +132,7 @@ class LinksObservation(AbstractObservation, identifier="LINKS"):
|
||||
:return: Gymnasium space representing the observation space for multiple links.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return {i + 1: l.space for i, l in enumerate(self.links)}
|
||||
return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import model_validator
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
@@ -28,33 +29,63 @@ class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""List of configurations for router observations."""
|
||||
firewalls: List[FirewallObservation.ConfigSchema] = []
|
||||
"""List of configurations for firewall observations."""
|
||||
num_services: int
|
||||
num_services: Optional[int] = None
|
||||
"""Number of services."""
|
||||
num_applications: int
|
||||
num_applications: Optional[int] = None
|
||||
"""Number of applications."""
|
||||
num_folders: int
|
||||
num_folders: Optional[int] = None
|
||||
"""Number of folders."""
|
||||
num_files: int
|
||||
num_files: Optional[int] = None
|
||||
"""Number of files."""
|
||||
num_nics: int
|
||||
num_nics: Optional[int] = None
|
||||
"""Number of network interface cards (NICs)."""
|
||||
include_nmne: bool
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Flag to include nmne."""
|
||||
include_num_access: bool
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Flag to include the number of accesses."""
|
||||
num_ports: int
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of ports."""
|
||||
ip_list: List[str]
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: List[str]
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: List[int]
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: List[str]
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: int
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def force_optional_fields(self) -> NodesObservation.ConfigSchema:
|
||||
"""Check that options are specified only if they are needed for the nodes that are part of the config."""
|
||||
# check for hosts:
|
||||
host_fields = (
|
||||
self.num_services,
|
||||
self.num_applications,
|
||||
self.num_folders,
|
||||
self.num_files,
|
||||
self.num_nics,
|
||||
self.include_nmne,
|
||||
self.include_num_access,
|
||||
)
|
||||
router_fields = (
|
||||
self.num_ports,
|
||||
self.ip_list,
|
||||
self.wildcard_list,
|
||||
self.port_list,
|
||||
self.protocol_list,
|
||||
self.num_rules,
|
||||
)
|
||||
firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules)
|
||||
if len(self.hosts) > 0 and any([x is None for x in host_fields]):
|
||||
raise ValueError("Configuration error: Host observation options were not fully specified.")
|
||||
if len(self.routers) > 0 and any([x is None for x in router_fields]):
|
||||
raise ValueError("Configuration error: Router observation options were not fully specified.")
|
||||
if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]):
|
||||
raise ValueError("Configuration error: Firewall observation options were not fully specified.")
|
||||
return self
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
@@ -120,6 +120,30 @@ class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
return cls(components=instances)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation, identifier="NONE"):
|
||||
"""Empty observation that acts as a placeholder."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the empty observation."""
|
||||
self.default_observation = 0
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""Simply return 0."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Essentially empty space."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: NullObservation.ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> NullObservation:
|
||||
"""Instantiate a NullObservation. Accepts parameters to comply with API."""
|
||||
return cls()
|
||||
|
||||
|
||||
class ObservationManager:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
@@ -156,7 +180,7 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
|
||||
def from_config(cls, config: Optional[Dict], game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""
|
||||
Create observation space from a config.
|
||||
|
||||
@@ -168,6 +192,9 @@ class ObservationManager:
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
print(config)
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)
|
||||
|
||||
Reference in New Issue
Block a user