Merged PR 321: CAOS 0.8 observations
## Summary * Remove the usecase-specific and agent-specific observation classes, replacing with a more flexible system * Add configuration schemas to every observation class * Add router, firewall, port, and application observation * Re-shape the dict structure of observations to make it adhere to CAOS 0.8 * Change existing configs to use the new structure * make host observation separate ## Test process existing and new unit tests as well as ad hoc notebooks ## Checklist - [ ] PR is linked to a **work item** - [ ] **acceptance criteria** of linked ticket are met - [ ] performed **self-review** of the code - [ ] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [ ] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2417
This commit is contained in:
@@ -120,7 +120,7 @@ SessionManager.
|
||||
- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios.
|
||||
- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules.
|
||||
- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them.
|
||||
- Integration of NMNE capturing functionality within the `NicObservation` class.
|
||||
- Integration of NMNE capturing functionality within the `NICObservation` class.
|
||||
- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario
|
||||
|
||||
### Removed
|
||||
|
||||
@@ -73,7 +73,7 @@ Network Interface Classes
|
||||
- Malicious Network Events Monitoring:
|
||||
|
||||
* Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns.
|
||||
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies.
|
||||
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NICObservation`` to classify and record network anomalies.
|
||||
* Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms.
|
||||
|
||||
**WiredNetworkInterface (Connection Type Layer)**
|
||||
|
||||
@@ -41,8 +41,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -91,8 +90,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -141,10 +139,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -177,61 +172,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -40,8 +40,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -90,8 +89,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -140,10 +138,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -179,61 +174,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -742,61 +749,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
# flake8: noqa
|
||||
# Pre-import all the observations when we load up the observations module so that they can be resolved by the parser.
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation
|
||||
from primaite.game.agent.observations.node_observations import NodesObservation
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
|
||||
# fmt: off
|
||||
__all__ = [
|
||||
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
|
||||
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
|
||||
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
|
||||
# fmt: on
|
||||
|
||||
189
src/primaite/game/agent/observations/acl_observation.py
Normal file
189
src/primaite/game/agent/observations/acl_observation.py
Normal file
@@ -0,0 +1,189 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ACLObservation(AbstractObservation, identifier="ACL"):
|
||||
"""ACL observation, provides information about access control lists within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ACLObservation."""
|
||||
|
||||
ip_list: Optional[List[IPv4Address]] = None
|
||||
"""List of IP addresses."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of wildcard strings."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of port numbers."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocol names."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of ACL rules."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
num_rules: int,
|
||||
ip_list: List[IPv4Address],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
) -> None:
|
||||
"""
|
||||
Initialise an ACL observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
|
||||
:type where: WhereType
|
||||
:param num_rules: Number of ACL rules.
|
||||
:type num_rules: int
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[IPv4Address]
|
||||
:param wildcard_list: List of wildcard strings.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param protocol_list: List of protocol names.
|
||||
:type protocol_list: List[str]
|
||||
"""
|
||||
self.where = where
|
||||
self.num_rules: int = num_rules
|
||||
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
|
||||
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
|
||||
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
|
||||
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing ACL rules.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_ip_id": 0,
|
||||
"source_wildcard_id": 0,
|
||||
"source_port_id": 0,
|
||||
"dest_ip_id": 0,
|
||||
"dest_wildcard_id": 0,
|
||||
"dest_port_id": 0,
|
||||
"protocol_id": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip]
|
||||
src_wildcard = rule_state["src_wildcard_mask"]
|
||||
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
|
||||
dst_wildcard = rule_state["dst_wildcard_mask"]
|
||||
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = self.port_to_id.get(src_port, 1)
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = self.port_to_id.get(dst_port, 1)
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = self.protocol_to_id.get(protocol, 1)
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_ip_id": src_node_id,
|
||||
"source_wildcard_id": src_wildcard_id,
|
||||
"source_port_id": src_port_id,
|
||||
"dest_ip_id": dst_node_id,
|
||||
"dest_wildcard_id": dst_wildcard_id,
|
||||
"dest_port_id": dst_port_id,
|
||||
"protocol_id": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for ACL rules.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
|
||||
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
|
||||
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
|
||||
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation:
|
||||
"""
|
||||
Create an ACL observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the ACL observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this ACL's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed ACL observation instance.
|
||||
:rtype: ACLObservation
|
||||
"""
|
||||
return cls(
|
||||
where=parent_where + ["acl", "acl"],
|
||||
num_rules=config.num_rules,
|
||||
ip_list=config.ip_list,
|
||||
wildcard_list=config.wildcard_list,
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
)
|
||||
@@ -1,188 +0,0 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.observations import (
|
||||
AbstractObservation,
|
||||
AclObservation,
|
||||
ICSObservation,
|
||||
LinkObservation,
|
||||
NullObservation,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
|
||||
"""Container for all observations used by the blue agent in UC2.
|
||||
|
||||
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
|
||||
for the purpose of compiling several observation components.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialise UC2 blue observation.
|
||||
|
||||
:param nodes: List of node observations
|
||||
:type nodes: List[NodeObservation]
|
||||
:param links: List of link observations
|
||||
:type links: List[LinkObservation]
|
||||
:param acl: The Access Control List observation
|
||||
:type acl: AclObservation
|
||||
:param ics: The ICS observation
|
||||
:type ics: ICSObservation
|
||||
:param where: Where in the simulation state dict to find information. Not used in this particular observation
|
||||
because it only compiles other observations and doesn't contribute any new information, defaults to None
|
||||
:type where: Optional[List[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: AclObservation = acl
|
||||
self.ics: ICSObservation = ics
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
|
||||
"ACL": self.acl.default_observation,
|
||||
"ICS": self.ics.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
obs["ICS"] = self.ics.observe(state)
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
links, ACL and ICS observations.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed UC2 blue observation
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
|
||||
num_services_per_node = config["num_services_per_node"]
|
||||
num_folders_per_node = config["num_folders_per_node"]
|
||||
num_files_per_folder = config["num_files_per_folder"]
|
||||
num_nics_per_node = config["num_nics_per_node"]
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
game=game,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
for n in node_configs
|
||||
]
|
||||
|
||||
link_configs = config["links"]
|
||||
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
|
||||
|
||||
acl_config = config["acl"]
|
||||
acl = AclObservation.from_config(config=acl_config, game=game)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, game=game)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
|
||||
"""Container for all observations used by the red agent in UC2."""
|
||||
|
||||
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 red observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
class UC2GreenObservation(NullObservation):
|
||||
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
|
||||
|
||||
pass
|
||||
@@ -1,126 +1,170 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Iterable, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
"""Observation of a file on a node in the network."""
|
||||
class FileObservation(AbstractObservation, identifier="FILE"):
|
||||
"""File observation, provides status information about a file within the simulation environment."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FileObservation."""
|
||||
|
||||
file_name: str
|
||||
"""Name of the file, used for querying simulation state dictionary."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to the file in the observation."""
|
||||
|
||||
def __init__(self, where: WhereType, include_num_access: bool) -> None:
|
||||
"""
|
||||
Initialise file observation.
|
||||
Initialise a file observation instance.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this file.
|
||||
A typical location for a file might be
|
||||
['network', 'nodes', <node_hostname>, 'file_system', 'folder', <folder_name>, 'files', <file_name>].
|
||||
:type where: WhereType
|
||||
:param include_num_access: Whether to include the number of accesses to the file in the observation.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
self.where: WhereType = where
|
||||
self.include_num_access: bool = include_num_access
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
self.default_observation: ObsType = {"health_status": 0}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_access"] = 0
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_access(self, num_access: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_access > self.high_threshold:
|
||||
return 3
|
||||
elif num_access > self.med_threshold:
|
||||
return 2
|
||||
elif num_access > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing the health status of the file and optionally the number of accesses.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["visible_status"]}
|
||||
obs = {"health_status": file_state["visible_status"]}
|
||||
if self.include_num_access:
|
||||
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:return: Gymnasium space representing the observation space for file status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
space = {"health_status": spaces.Discrete(6)}
|
||||
if self.include_num_access:
|
||||
space["num_access"] = spaces.Discrete(4)
|
||||
return spaces.Dict(space)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
:type config: Dict
|
||||
:param game: _description_
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: _description_, defaults to None
|
||||
:type parent_where: _type_, optional
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation:
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config["file_name"]])
|
||||
Create a file observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the file observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this file's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed file observation instance.
|
||||
:rtype: FileObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
"""Folder observation, including files inside of the folder."""
|
||||
class FolderObservation(AbstractObservation, identifier="FOLDER"):
|
||||
"""Folder observation, provides status information about a folder within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FolderObservation."""
|
||||
|
||||
folder_name: str
|
||||
"""Name of the folder, used for querying simulation state dictionary."""
|
||||
files: List[FileObservation.ConfigSchema] = []
|
||||
"""List of file configurations within the folder."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of spaces for file observations in this folder."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether files in this folder should include the number of accesses in their observation."""
|
||||
|
||||
def __init__(
|
||||
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
|
||||
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
|
||||
) -> None:
|
||||
"""Initialise folder Observation, including files inside the folder.
|
||||
"""
|
||||
Initialise a folder observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param max_files: As size of the space must remain static, define max files that can be in this folder
|
||||
, defaults to 5
|
||||
:type max_files: int, optional
|
||||
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
|
||||
that even if new files are created, the existing files will always occupy the same space in the observation
|
||||
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
|
||||
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
|
||||
name, it will take the position defined in this dict. Defaults to {}
|
||||
:type file_positions: Dict[int, str], optional
|
||||
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
|
||||
:type where: WhereType
|
||||
:param files: List of file observation instances within the folder.
|
||||
:type files: Iterable[FileObservation]
|
||||
:param num_files: Number of files expected in the folder.
|
||||
:type num_files: int
|
||||
:param include_num_access: Whether to include the number of accesses to files in the observation.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.where: WhereType = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files_per_folder:
|
||||
self.files.append(FileObservation())
|
||||
while len(self.files) > num_files_per_folder:
|
||||
while len(self.files) < num_files:
|
||||
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
|
||||
while len(self.files) > num_files:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
if self.files:
|
||||
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the health status of the folder and status of files within the folder.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
@@ -130,48 +174,42 @@ class FolderObservation(AbstractObservation):
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
if self.files:
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:return: Gymnasium space representing the observation space for folder status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"health_status": spaces.Discrete(6),
|
||||
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
|
||||
}
|
||||
)
|
||||
shape = {"health_status": spaces.Discrete(6)}
|
||||
if self.files:
|
||||
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:
|
||||
"""
|
||||
Create a folder observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
|
||||
folder and the files inside of it.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param config: Configuration schema containing the necessary information for the folder observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_hostname>,'file_system']
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed folder observation
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed folder observation instance.
|
||||
:rtype: FolderObservation
|
||||
"""
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
where = parent_where + ["folders", config.folder_name]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
|
||||
# pass down shared/common config items
|
||||
for file_config in config.files:
|
||||
file_config.include_num_access = config.include_num_access
|
||||
|
||||
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files]
|
||||
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)
|
||||
|
||||
224
src/primaite/game/agent/observations/firewall_observation.py
Normal file
224
src/primaite/game/agent/observations/firewall_observation.py
Normal file
@@ -0,0 +1,224 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
|
||||
"""Firewall observation, provides status information about a firewall within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for FirewallObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the firewall node, used for querying simulation state dictionary."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
ip_list: List[str],
|
||||
wildcard_list: List[str],
|
||||
port_list: List[int],
|
||||
protocol_list: List[str],
|
||||
num_rules: int,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a firewall observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
|
||||
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
|
||||
:type where: WhereType
|
||||
:param ip_list: List of IP addresses.
|
||||
:type ip_list: List[str]
|
||||
:param wildcard_list: List of wildcard rules.
|
||||
:type wildcard_list: List[str]
|
||||
:param port_list: List of port numbers.
|
||||
:type port_list: List[int]
|
||||
:param protocol_list: List of protocol types.
|
||||
:type protocol_list: List[str]
|
||||
:param num_rules: Number of rules configured in the firewall.
|
||||
:type num_rules: int
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.ports: List[PortObservation] = [
|
||||
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
|
||||
]
|
||||
# TODO: check what the port nums are for firewall.
|
||||
|
||||
self.internal_inbound_acl = ACLObservation(
|
||||
where=self.where + ["internal_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.internal_outbound_acl = ACLObservation(
|
||||
where=self.where + ["internal_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_inbound_acl = ACLObservation(
|
||||
where=self.where + ["dmz_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.dmz_outbound_acl = ACLObservation(
|
||||
where=self.where + ["dmz_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_inbound_acl = ACLObservation(
|
||||
where=self.where + ["external_inbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
self.external_outbound_acl = ACLObservation(
|
||||
where=self.where + ["external_outbound_acl", "acl"],
|
||||
num_rules=num_rules,
|
||||
ip_list=ip_list,
|
||||
wildcard_list=wildcard_list,
|
||||
port_list=port_list,
|
||||
protocol_list=protocol_list,
|
||||
)
|
||||
|
||||
self.default_observation = {
|
||||
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.internal_outbound_acl.default_observation,
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.dmz_outbound_acl.default_observation,
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.default_observation,
|
||||
"OUTBOUND": self.external_outbound_acl.default_observation,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
obs = {
|
||||
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
|
||||
"ACL": {
|
||||
"INTERNAL": {
|
||||
"INBOUND": self.internal_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.internal_outbound_acl.observe(state),
|
||||
},
|
||||
"DMZ": {
|
||||
"INBOUND": self.dmz_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.dmz_outbound_acl.observe(state),
|
||||
},
|
||||
"EXTERNAL": {
|
||||
"INBOUND": self.external_inbound_acl.observe(state),
|
||||
"OUTBOUND": self.external_outbound_acl.observe(state),
|
||||
},
|
||||
},
|
||||
}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for firewall status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
|
||||
"ACL": spaces.Dict(
|
||||
{
|
||||
"INTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.internal_inbound_acl.space,
|
||||
"OUTBOUND": self.internal_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"DMZ": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.dmz_inbound_acl.space,
|
||||
"OUTBOUND": self.dmz_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
"EXTERNAL": spaces.Dict(
|
||||
{
|
||||
"INBOUND": self.external_inbound_acl.space,
|
||||
"OUTBOUND": self.external_outbound_acl.space,
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> FirewallObservation:
|
||||
"""
|
||||
Create a firewall observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the firewall observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this firewall's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <firewall_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed firewall observation instance.
|
||||
:rtype: FirewallObservation
|
||||
"""
|
||||
return cls(
|
||||
where=parent_where + ["nodes", config.hostname],
|
||||
ip_list=config.ip_list,
|
||||
wildcard_list=config.wildcard_list,
|
||||
port_list=config.port_list,
|
||||
protocol_list=config.protocol_list,
|
||||
num_rules=config.num_rules,
|
||||
)
|
||||
248
src/primaite/game/agent/observations/host_observations.py
Normal file
248
src/primaite/game/agent/observations/host_observations.py
Normal file
@@ -0,0 +1,248 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class HostObservation(AbstractObservation, identifier="HOST"):
|
||||
"""Host observation, provides status information about a host within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for HostObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the host, used for querying simulation state dictionary."""
|
||||
services: List[ServiceObservation.ConfigSchema] = []
|
||||
"""List of services to observe on the host."""
|
||||
applications: List[ApplicationObservation.ConfigSchema] = []
|
||||
"""List of applications to observe on the host."""
|
||||
folders: List[FolderObservation.ConfigSchema] = []
|
||||
"""List of folders to observe on the host."""
|
||||
network_interfaces: List[NICObservation.ConfigSchema] = []
|
||||
"""List of network interfaces to observe on the host."""
|
||||
num_services: Optional[int] = None
|
||||
"""Number of spaces for service observations on this host."""
|
||||
num_applications: Optional[int] = None
|
||||
"""Number of spaces for application observations on this host."""
|
||||
num_folders: Optional[int] = None
|
||||
"""Number of spaces for folder observations on this host."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of spaces for file observations on this host."""
|
||||
num_nics: Optional[int] = None
|
||||
"""Number of spaces for network interface observations on this host."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Whether network interface observations should include number of malicious network events."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Whether to include the number of accesses to files observations on this host."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
services: List[ServiceObservation],
|
||||
applications: List[ApplicationObservation],
|
||||
folders: List[FolderObservation],
|
||||
network_interfaces: List[NICObservation],
|
||||
num_services: int,
|
||||
num_applications: int,
|
||||
num_folders: int,
|
||||
num_files: int,
|
||||
num_nics: int,
|
||||
include_nmne: bool,
|
||||
include_num_access: bool,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a host observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this host.
|
||||
A typical location for a host might be ['network', 'nodes', <hostname>].
|
||||
:type where: WhereType
|
||||
:param services: List of service observations on the host.
|
||||
:type services: List[ServiceObservation]
|
||||
:param applications: List of application observations on the host.
|
||||
:type applications: List[ApplicationObservation]
|
||||
:param folders: List of folder observations on the host.
|
||||
:type folders: List[FolderObservation]
|
||||
:param network_interfaces: List of network interface observations on the host.
|
||||
:type network_interfaces: List[NICObservation]
|
||||
:param num_services: Number of services to observe.
|
||||
:type num_services: int
|
||||
:param num_applications: Number of applications to observe.
|
||||
:type num_applications: int
|
||||
:param num_folders: Number of folders to observe.
|
||||
:type num_folders: int
|
||||
:param num_files: Number of files.
|
||||
:type num_files: int
|
||||
:param num_nics: Number of network interfaces.
|
||||
:type num_nics: int
|
||||
:param include_nmne: Flag to include network metrics and errors.
|
||||
:type include_nmne: bool
|
||||
:param include_num_access: Flag to include the number of accesses to files.
|
||||
:type include_num_access: bool
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
|
||||
self.include_num_access = include_num_access
|
||||
|
||||
# Ensure lists have lengths equal to specified counts by truncating or padding
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services:
|
||||
self.services.append(ServiceObservation(where=None))
|
||||
while len(self.services) > num_services:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.applications: List[ApplicationObservation] = applications
|
||||
while len(self.applications) < num_applications:
|
||||
self.applications.append(ApplicationObservation(where=None))
|
||||
while len(self.applications) > num_applications:
|
||||
truncated_application = self.applications.pop()
|
||||
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
while len(self.folders) < num_folders:
|
||||
self.folders.append(
|
||||
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
|
||||
)
|
||||
while len(self.folders) > num_folders:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.nics: List[NICObservation] = network_interfaces
|
||||
while len(self.nics) < num_nics:
|
||||
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
|
||||
while len(self.nics) > num_nics:
|
||||
truncated_nic = self.nics.pop()
|
||||
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation: ObsType = {
|
||||
"operating_status": 0,
|
||||
}
|
||||
if self.services:
|
||||
self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)}
|
||||
if self.applications:
|
||||
self.default_observation["APPLICATIONS"] = {
|
||||
i + 1: a.default_observation for i, a in enumerate(self.applications)
|
||||
}
|
||||
if self.folders:
|
||||
self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
self.default_observation["num_file_creations"] = 0
|
||||
self.default_observation["num_file_deletions"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status information about the host.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
if self.services:
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
if self.applications:
|
||||
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
|
||||
if self.folders:
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
if self.nics:
|
||||
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
|
||||
if self.include_num_access:
|
||||
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
|
||||
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for host status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
shape = {
|
||||
"operating_status": spaces.Discrete(5),
|
||||
}
|
||||
if self.services:
|
||||
shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)})
|
||||
if self.applications:
|
||||
shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)})
|
||||
if self.folders:
|
||||
shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)})
|
||||
if self.nics:
|
||||
shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)})
|
||||
if self.include_num_access:
|
||||
shape["num_file_creations"] = spaces.Discrete(4)
|
||||
shape["num_file_deletions"] = spaces.Discrete(4)
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation:
|
||||
"""
|
||||
Create a host observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the host observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this host.
|
||||
A typical location might be ['network', 'nodes', <hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed host observation instance.
|
||||
:rtype: HostObservation
|
||||
"""
|
||||
if parent_where == []:
|
||||
where = ["network", "nodes", config.hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
|
||||
# Pass down shared/common config items
|
||||
for folder_config in config.folders:
|
||||
folder_config.include_num_access = config.include_num_access
|
||||
folder_config.num_files = config.num_files
|
||||
for nic_config in config.network_interfaces:
|
||||
nic_config.include_nmne = config.include_nmne
|
||||
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services]
|
||||
applications = [
|
||||
ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications
|
||||
]
|
||||
folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders]
|
||||
nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces]
|
||||
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
applications=applications,
|
||||
folders=folders,
|
||||
network_interfaces=nics,
|
||||
num_services=config.num_services,
|
||||
num_applications=config.num_applications,
|
||||
num_folders=config.num_folders,
|
||||
num_files=config.num_files,
|
||||
num_nics=config.num_nics,
|
||||
include_nmne=config.include_nmne,
|
||||
include_num_access=config.include_num_access,
|
||||
)
|
||||
155
src/primaite/game/agent/observations/link_observation.py
Normal file
155
src/primaite/game/agent/observations/link_observation.py
Normal file
@@ -0,0 +1,155 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation, identifier="LINK"):
|
||||
"""Link observation, providing information about a specific link within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinkObservation."""
|
||||
|
||||
link_reference: str
|
||||
"""Reference identifier for the link."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a link observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this link.
|
||||
A typical location for a link might be ['network', 'links', <link_reference>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about the link.
|
||||
:rtype: Any
|
||||
"""
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for link status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation:
|
||||
"""
|
||||
Create a link observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the link observation.
|
||||
:type config: ConfigSchema
|
||||
:param game: The PrimaiteGame instance.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this link.
|
||||
A typical location might be ['network', 'links', <link_reference>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed link observation instance.
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
link_reference = game.ref_map_links[config.link_reference]
|
||||
if parent_where == []:
|
||||
where = ["network", "links", link_reference]
|
||||
else:
|
||||
where = parent_where + ["links", link_reference]
|
||||
return cls(where=where)
|
||||
|
||||
|
||||
class LinksObservation(AbstractObservation, identifier="LINKS"):
|
||||
"""Collection of link observations representing multiple links within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for LinksObservation."""
|
||||
|
||||
link_references: List[str]
|
||||
"""List of reference identifiers for the links."""
|
||||
|
||||
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
|
||||
"""
|
||||
Initialise a links observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for these links.
|
||||
A typical location for links might be ['network', 'links'].
|
||||
:type where: WhereType
|
||||
:param links: List of link observations.
|
||||
:type links: List[LinkObservation]
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.links: List[LinkObservation] = links
|
||||
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing information about multiple links.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for multiple links.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:
|
||||
"""
|
||||
Create a links observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the links observation.
|
||||
:type config: ConfigSchema
|
||||
:param game: The PrimaiteGame instance.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about these links.
|
||||
A typical location might be ['network'].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed links observation instance.
|
||||
:rtype: LinksObservation
|
||||
"""
|
||||
where = parent_where + ["network"]
|
||||
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
|
||||
links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs]
|
||||
return cls(where=where, links=links)
|
||||
@@ -1,97 +1,56 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
|
||||
"""Status information about a network interface within the simulation environment."""
|
||||
|
||||
low_nmne_threshold: int = 0
|
||||
"""The minimum number of malicious network events to be considered low."""
|
||||
med_nmne_threshold: int = 5
|
||||
"""The minimum number of malicious network events to be considered medium."""
|
||||
high_nmne_threshold: int = 10
|
||||
"""The minimum number of malicious network events to be considered high."""
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NICObservation."""
|
||||
|
||||
global CAPTURE_NMNE
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
data = {"nic_status": 0}
|
||||
if CAPTURE_NMNE:
|
||||
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
|
||||
return data
|
||||
nic_num: int
|
||||
"""Number of the network interface."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Whether to include number of malicious network events (NMNE) in the observation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
low_nmne_threshold: Optional[int] = 0,
|
||||
med_nmne_threshold: Optional[int] = 5,
|
||||
high_nmne_threshold: Optional[int] = 10,
|
||||
where: WhereType,
|
||||
include_nmne: bool,
|
||||
) -> None:
|
||||
"""Initialise NIC observation.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_hostname>,'NICs',<nic_number>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialise a network interface observation instance.
|
||||
|
||||
global CAPTURE_NMNE
|
||||
if CAPTURE_NMNE:
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
|
||||
A typical location for a network interface might be
|
||||
['network', 'nodes', <node_hostname>, 'NICs', <nic_num>].
|
||||
:type where: WhereType
|
||||
:param include_nmne: Flag to determine whether to include NMNE information in the observation.
|
||||
:type include_nmne: bool
|
||||
"""
|
||||
self.where = where
|
||||
self.include_nmne: bool = include_nmne
|
||||
|
||||
self.default_observation: ObsType = {"nic_status": 0}
|
||||
if self.include_nmne:
|
||||
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
|
||||
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
|
||||
self._validate_nmne_categories(
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
|
||||
def _validate_nmne_categories(
|
||||
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
|
||||
):
|
||||
"""
|
||||
Validates the nmne threshold config.
|
||||
|
||||
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
|
||||
|
||||
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
|
||||
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
|
||||
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
|
||||
"""
|
||||
if high_nmne_threshold <= med_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
|
||||
f"than medium nmne count ({med_nmne_threshold})"
|
||||
)
|
||||
|
||||
if med_nmne_threshold <= low_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
|
||||
f"than low nmne count ({low_nmne_threshold})"
|
||||
)
|
||||
|
||||
self.high_nmne_threshold = high_nmne_threshold
|
||||
self.med_nmne_threshold = med_nmne_threshold
|
||||
self.low_nmne_threshold = low_nmne_threshold
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_nmne_threshold = 10
|
||||
self.med_nmne_threshold = 5
|
||||
self.low_nmne_threshold = 0
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
@@ -116,73 +75,120 @@ class NicObservation(AbstractObservation):
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of the network interface and optionally NMNE information.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if CAPTURE_NMNE:
|
||||
obs_dict.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs_dict
|
||||
|
||||
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if self.include_nmne:
|
||||
obs.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for network interface status and NMNE information.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
if CAPTURE_NMNE:
|
||||
if self.include_nmne:
|
||||
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation:
|
||||
"""
|
||||
low_nmne_threshold = None
|
||||
med_nmne_threshold = None
|
||||
high_nmne_threshold = None
|
||||
Create a network interface observation from a configuration schema.
|
||||
|
||||
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
|
||||
threshold = game.options.thresholds["nmne"]
|
||||
:param config: Configuration schema containing the necessary information for the network interface observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed network interface observation instance.
|
||||
:rtype: NICObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne)
|
||||
|
||||
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
|
||||
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
|
||||
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
|
||||
|
||||
return cls(
|
||||
where=parent_where + ["NICs", config["nic_num"]],
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
class PortObservation(AbstractObservation, identifier="PORT"):
|
||||
"""Port observation, provides status information about a network port within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for PortObservation."""
|
||||
|
||||
port_id: int
|
||||
"""Identifier of the port, used for querying simulation state dictionary."""
|
||||
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
Initialise a port observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this port.
|
||||
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation: ObsType = {"operating_status": 0}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the operating status of the port.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
port_state = access_from_nested_dict(state, self.where)
|
||||
if port_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"operating_status": 1 if port_state["enabled"] else 2}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for port status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(3)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation:
|
||||
"""
|
||||
Create a port observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the port observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this port's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed port observation instance.
|
||||
:rtype: PortObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["NICs", config.port_id])
|
||||
|
||||
@@ -1,200 +1,218 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import model_validator
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.software_observation import ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class NodeObservation(AbstractObservation):
|
||||
"""Observation of a node in the network. Includes services, folders and NICs."""
|
||||
class NodesObservation(AbstractObservation, identifier="NODES"):
|
||||
"""Nodes observation, provides status information about nodes within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NodesObservation."""
|
||||
|
||||
hosts: List[HostObservation.ConfigSchema] = []
|
||||
"""List of configurations for host observations."""
|
||||
routers: List[RouterObservation.ConfigSchema] = []
|
||||
"""List of configurations for router observations."""
|
||||
firewalls: List[FirewallObservation.ConfigSchema] = []
|
||||
"""List of configurations for firewall observations."""
|
||||
num_services: Optional[int] = None
|
||||
"""Number of services."""
|
||||
num_applications: Optional[int] = None
|
||||
"""Number of applications."""
|
||||
num_folders: Optional[int] = None
|
||||
"""Number of folders."""
|
||||
num_files: Optional[int] = None
|
||||
"""Number of files."""
|
||||
num_nics: Optional[int] = None
|
||||
"""Number of network interface cards (NICs)."""
|
||||
include_nmne: Optional[bool] = None
|
||||
"""Flag to include nmne."""
|
||||
include_num_access: Optional[bool] = None
|
||||
"""Flag to include the number of accesses."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of ports."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def force_optional_fields(self) -> NodesObservation.ConfigSchema:
|
||||
"""Check that options are specified only if they are needed for the nodes that are part of the config."""
|
||||
# check for hosts:
|
||||
host_fields = (
|
||||
self.num_services,
|
||||
self.num_applications,
|
||||
self.num_folders,
|
||||
self.num_files,
|
||||
self.num_nics,
|
||||
self.include_nmne,
|
||||
self.include_num_access,
|
||||
)
|
||||
router_fields = (
|
||||
self.num_ports,
|
||||
self.ip_list,
|
||||
self.wildcard_list,
|
||||
self.port_list,
|
||||
self.protocol_list,
|
||||
self.num_rules,
|
||||
)
|
||||
firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules)
|
||||
if len(self.hosts) > 0 and any([x is None for x in host_fields]):
|
||||
raise ValueError("Configuration error: Host observation options were not fully specified.")
|
||||
if len(self.routers) > 0 and any([x is None for x in router_fields]):
|
||||
raise ValueError("Configuration error: Router observation options were not fully specified.")
|
||||
if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]):
|
||||
raise ValueError("Configuration error: Firewall observation options were not fully specified.")
|
||||
return self
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
network_interfaces: List[NicObservation] = [],
|
||||
logon_status: bool = False,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
where: WhereType,
|
||||
hosts: List[HostObservation],
|
||||
routers: List[RouterObservation],
|
||||
firewalls: List[FirewallObservation],
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
Initialise a nodes observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary for find relevant information for this observation.
|
||||
A typical location for a node looks like this:
|
||||
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
|
||||
:type where: List[str], optional
|
||||
:param services: Mapping between position in observation space and service name, defaults to {}
|
||||
:type services: Dict[int,str], optional
|
||||
:param max_services: Max number of services that can be presented in observation space for this node
|
||||
, defaults to 2
|
||||
:type max_services: int, optional
|
||||
:param folders: Mapping between position in observation space and folder name, defaults to {}
|
||||
:type folders: Dict[int,str], optional
|
||||
:param max_folders: Max number of folders in this node's obs space, defaults to 2
|
||||
:type max_folders: int, optional
|
||||
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
|
||||
:type network_interfaces: Dict[int,str], optional
|
||||
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for nodes.
|
||||
A typical location for nodes might be ['network', 'nodes'].
|
||||
:type where: WhereType
|
||||
:param hosts: List of host observations.
|
||||
:type hosts: List[HostObservation]
|
||||
:param routers: List of router observations.
|
||||
:type routers: List[RouterObservation]
|
||||
:param firewalls: List of firewall observations.
|
||||
:type firewalls: List[FirewallObservation]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.where: WhereType = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services_per_node:
|
||||
# add empty service observation without `where` parameter so it always returns default (blank) observation
|
||||
self.services.append(ServiceObservation())
|
||||
while len(self.services) > num_services_per_node:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
# truncate service list
|
||||
self.hosts: List[HostObservation] = hosts
|
||||
self.routers: List[RouterObservation] = routers
|
||||
self.firewalls: List[FirewallObservation] = firewalls
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
# add empty folder observation without `where` parameter that will always return default (blank) observations
|
||||
while len(self.folders) < num_folders_per_node:
|
||||
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
|
||||
while len(self.folders) > num_folders_per_node:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.network_interfaces: List[NicObservation] = network_interfaces
|
||||
while len(self.network_interfaces) < num_nics_per_node:
|
||||
self.network_interfaces.append(NicObservation())
|
||||
while len(self.network_interfaces) > num_nics_per_node:
|
||||
truncated_nic = self.network_interfaces.pop()
|
||||
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
self.default_observation = {
|
||||
**{f"HOST{i}": host.default_observation for i, host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}": router.default_observation for i, router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
if self.logon_status:
|
||||
self.default_observation["logon_status"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing status information about nodes.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
obs = {
|
||||
**{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}": router.observe(state) for i, router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}": firewall.observe(state) for i, firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
|
||||
if self.logon_status:
|
||||
obs["logon_status"] = 0
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
space_shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
}
|
||||
if self.logon_status:
|
||||
space_shape["logon_status"] = spaces.Discrete(3)
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
return spaces.Dict(space_shape)
|
||||
:return: Gymnasium space representing the observation space for nodes.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
space = spaces.Dict(
|
||||
{
|
||||
**{f"HOST{i}": host.space for i, host in enumerate(self.hosts)},
|
||||
**{f"ROUTER{i}": router.space for i, router in enumerate(self.routers)},
|
||||
**{f"FIREWALL{i}": firewall.space for i, firewall in enumerate(self.firewalls)},
|
||||
}
|
||||
)
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
game: "PrimaiteGame",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> "NodeObservation":
|
||||
"""Create node observation from a config. Also creates child service, folder and NIC observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this node observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
|
||||
network. A typical location for it would be: ['network',]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_services_per_node: int, optional
|
||||
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_folders_per_node: int, optional
|
||||
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation:
|
||||
"""
|
||||
node_hostname = config["node_hostname"]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_hostname]
|
||||
Create a nodes observation from a configuration schema.
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
|
||||
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
|
||||
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
folders=folders,
|
||||
network_interfaces=network_interfaces,
|
||||
logon_status=logon_status,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
:param config: Configuration schema containing the necessary information for nodes observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about nodes.
|
||||
A typical location for nodes might be ['network', 'nodes'].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed nodes observation instance.
|
||||
:rtype: NodesObservation
|
||||
"""
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes"]
|
||||
else:
|
||||
where = parent_where + ["nodes"]
|
||||
|
||||
for host_config in config.hosts:
|
||||
if host_config.num_services is None:
|
||||
host_config.num_services = config.num_services
|
||||
if host_config.num_applications is None:
|
||||
host_config.num_applications = config.num_applications
|
||||
if host_config.num_folders is None:
|
||||
host_config.num_folders = config.num_folders
|
||||
if host_config.num_files is None:
|
||||
host_config.num_files = config.num_files
|
||||
if host_config.num_nics is None:
|
||||
host_config.num_nics = config.num_nics
|
||||
if host_config.include_nmne is None:
|
||||
host_config.include_nmne = config.include_nmne
|
||||
if host_config.include_num_access is None:
|
||||
host_config.include_num_access = config.include_num_access
|
||||
|
||||
for router_config in config.routers:
|
||||
if router_config.num_ports is None:
|
||||
router_config.num_ports = config.num_ports
|
||||
if router_config.ip_list is None:
|
||||
router_config.ip_list = config.ip_list
|
||||
if router_config.wildcard_list is None:
|
||||
router_config.wildcard_list = config.wildcard_list
|
||||
if router_config.port_list is None:
|
||||
router_config.port_list = config.port_list
|
||||
if router_config.protocol_list is None:
|
||||
router_config.protocol_list = config.protocol_list
|
||||
if router_config.num_rules is None:
|
||||
router_config.num_rules = config.num_rules
|
||||
|
||||
for firewall_config in config.firewalls:
|
||||
if firewall_config.ip_list is None:
|
||||
firewall_config.ip_list = config.ip_list
|
||||
if firewall_config.wildcard_list is None:
|
||||
firewall_config.wildcard_list = config.wildcard_list
|
||||
if firewall_config.port_list is None:
|
||||
firewall_config.port_list = config.port_list
|
||||
if firewall_config.protocol_list is None:
|
||||
firewall_config.protocol_list = config.protocol_list
|
||||
if firewall_config.num_rules is None:
|
||||
firewall_config.num_rules = config.num_rules
|
||||
|
||||
hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts]
|
||||
routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers]
|
||||
firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls]
|
||||
|
||||
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)
|
||||
|
||||
@@ -1,18 +1,149 @@
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
|
||||
|
||||
from primaite.game.agent.observations.agent_observations import (
|
||||
UC2BlueObservation,
|
||||
UC2GreenObservation,
|
||||
UC2RedObservation,
|
||||
)
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NestedObservation(AbstractObservation, identifier="CUSTOM"):
|
||||
"""Observation type that allows combining other observations into a gymnasium.spaces.Dict space."""
|
||||
|
||||
class NestedObservationItem(BaseModel):
|
||||
"""One list item of the config."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
type: str
|
||||
"""Select observation class. It maps to the identifier of the obs class by checking the registry."""
|
||||
label: str
|
||||
"""Dict key in the final observation space."""
|
||||
options: Dict
|
||||
"""Options to pass to the observation class from_config method."""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_model(self) -> "NestedObservation.NestedObservationItem":
|
||||
"""Make sure tha the config options match up with the selected observation type."""
|
||||
obs_subclass_name = self.type
|
||||
obs_options = self.options
|
||||
if obs_subclass_name not in AbstractObservation._registry:
|
||||
raise ValueError(f"Observation of type {obs_subclass_name} could not be found.")
|
||||
obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema
|
||||
try:
|
||||
obs_schema(**obs_options)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Observation options did not match schema, got this error: {e}")
|
||||
return self
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for NestedObservation."""
|
||||
|
||||
components: List[NestedObservation.NestedObservationItem] = []
|
||||
"""List of observation components to be part of this space."""
|
||||
|
||||
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
|
||||
"""Initialise nested observation."""
|
||||
self.components: Dict[str, AbstractObservation] = components
|
||||
"""Maps label: observation object"""
|
||||
|
||||
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
|
||||
"""Default observation is just the default observations of constituents."""
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status information about the host.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
return {label: obs.observe(state) for label, obs in self.components.items()}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the nested observation space.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation:
|
||||
"""
|
||||
Read the Nested observation config and create all defined subcomponents.
|
||||
|
||||
Example configuration that utilises NestedObservation:
|
||||
This lets us have different options for different types of hosts.
|
||||
|
||||
```yaml
|
||||
observation_space:
|
||||
- type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
|
||||
- type: HOSTS
|
||||
label: COMPUTERS # What is the dictionary key called
|
||||
options:
|
||||
hosts:
|
||||
- client_1
|
||||
- client_2
|
||||
num_services: 0
|
||||
num_applications: 5
|
||||
... # other options
|
||||
|
||||
- type: HOSTS
|
||||
label: SERVERS # What is the dictionary key called
|
||||
options:
|
||||
hosts:
|
||||
- hostname: database_server
|
||||
- hostname: web_server
|
||||
num_services: 4
|
||||
num_applications: 0
|
||||
num_folders: 2
|
||||
num_files: 2
|
||||
|
||||
```
|
||||
"""
|
||||
instances = dict()
|
||||
for component in config.components:
|
||||
obs_class = AbstractObservation._registry[component.type]
|
||||
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game)
|
||||
instances[component.label] = obs_instance
|
||||
return cls(components=instances)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation, identifier="NONE"):
|
||||
"""Empty observation that acts as a placeholder."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise the empty observation."""
|
||||
self.default_observation = 0
|
||||
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""Simply return 0."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Essentially empty space."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: NullObservation.ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> NullObservation:
|
||||
"""Instantiate a NullObservation. Accepts parameters to comply with API."""
|
||||
return cls()
|
||||
|
||||
|
||||
class ObservationManager:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
@@ -23,18 +154,15 @@ class ObservationManager:
|
||||
3. Formatting this information so an agent can use it to make decisions.
|
||||
"""
|
||||
|
||||
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
|
||||
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
|
||||
# refactor.
|
||||
|
||||
def __init__(self, observation: AbstractObservation) -> None:
|
||||
def __init__(self, obs: AbstractObservation) -> None:
|
||||
"""Initialise observation space.
|
||||
|
||||
:param observation: Observation object
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
self.obs: AbstractObservation = obs
|
||||
self.current_observation: ObsType
|
||||
"""Cached copy of the observation at the time it was most recently calculated."""
|
||||
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
@@ -52,22 +180,25 @@ class ObservationManager:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
def from_config(cls, config: Optional[Dict], game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""
|
||||
Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
It should contain the key 'type' which selects which observation class to use (from a choice of:
|
||||
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
|
||||
The other key is 'options' which are passed to the constructor of the selected observation class.
|
||||
If None, a blank observation space is created.
|
||||
Otherwise, this must be a Dict with a type field and options field.
|
||||
type: string that corresponds to one of the observation identifiers that are provided when subclassing
|
||||
AbstractObservation
|
||||
options: this must adhere to the chosen observation type's ConfigSchema nested class.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
if config is None:
|
||||
return cls(NullObservation())
|
||||
print(config)
|
||||
obs_type = config["type"]
|
||||
obs_class = AbstractObservation._registry[obs_type]
|
||||
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)
|
||||
obs_manager = cls(observation)
|
||||
return obs_manager
|
||||
|
||||
@@ -1,22 +1,50 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Dict, Iterable, Type, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
WhereType = Iterable[str | int] | None
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
"""Abstract class for an observation space component."""
|
||||
|
||||
class ConfigSchema(ABC, BaseModel):
|
||||
"""Config schema for observations."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
_registry: Dict[str, Type["AbstractObservation"]] = {}
|
||||
"""Registry of observation components, with their name as key.
|
||||
|
||||
Automatically populated when subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialise an observation. This method must be overwritten."""
|
||||
self.default_observation: ObsType
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register an observation type.
|
||||
|
||||
:param identifier: Identifier used to uniquely specify observation component types.
|
||||
:type identifier: str
|
||||
:raises ValueError: When attempting to create a component with a name that is already in use.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate observation component type {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
|
||||
@abstractmethod
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
@@ -37,273 +65,8 @@ class AbstractObservation(ABC):
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: List[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_hostname"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.network_interface[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
def from_config(
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> "AbstractObservation":
|
||||
"""Create this observation space component form a serialised format."""
|
||||
return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
pass
|
||||
|
||||
147
src/primaite/game/agent/observations/router_observation.py
Normal file
147
src/primaite/game/agent/observations/router_observation.py
Normal file
@@ -0,0 +1,147 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RouterObservation(AbstractObservation, identifier="ROUTER"):
|
||||
"""Router observation, provides status information about a router within the simulation environment."""
|
||||
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for RouterObservation."""
|
||||
|
||||
hostname: str
|
||||
"""Hostname of the router, used for querying simulation state dictionary."""
|
||||
ports: Optional[List[PortObservation.ConfigSchema]] = None
|
||||
"""Configuration of port observations for this router."""
|
||||
num_ports: Optional[int] = None
|
||||
"""Number of port observations configured for this router."""
|
||||
acl: Optional[ACLObservation.ConfigSchema] = None
|
||||
"""Configuration of ACL observation on this router."""
|
||||
ip_list: Optional[List[str]] = None
|
||||
"""List of IP addresses for encoding ACLs."""
|
||||
wildcard_list: Optional[List[str]] = None
|
||||
"""List of IP wildcards for encoding ACLs."""
|
||||
port_list: Optional[List[int]] = None
|
||||
"""List of ports for encoding ACLs."""
|
||||
protocol_list: Optional[List[str]] = None
|
||||
"""List of protocols for encoding ACLs."""
|
||||
num_rules: Optional[int] = None
|
||||
"""Number of rules ACL rules to show."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: WhereType,
|
||||
ports: List[PortObservation],
|
||||
num_ports: int,
|
||||
acl: ACLObservation,
|
||||
) -> None:
|
||||
"""
|
||||
Initialise a router observation instance.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this router.
|
||||
A typical location for a router might be ['network', 'nodes', <node_hostname>].
|
||||
:type where: WhereType
|
||||
:param ports: List of port observations representing the ports of the router.
|
||||
:type ports: List[PortObservation]
|
||||
:param num_ports: Number of ports for the router.
|
||||
:type num_ports: int
|
||||
:param acl: ACL observation representing the access control list of the router.
|
||||
:type acl: ACLObservation
|
||||
"""
|
||||
self.where: WhereType = where
|
||||
self.ports: List[PortObservation] = ports
|
||||
self.acl: ACLObservation = acl
|
||||
self.num_ports: int = num_ports
|
||||
|
||||
while len(self.ports) < num_ports:
|
||||
self.ports.append(PortObservation(where=None))
|
||||
while len(self.ports) > num_ports:
|
||||
self.ports.pop()
|
||||
msg = "Too many ports in router observation. Truncating."
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"ACL": self.acl.default_observation,
|
||||
}
|
||||
if self.ports:
|
||||
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation containing the status of ports and ACL configuration of the router.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
router_state = access_from_nested_dict(state, self.where)
|
||||
if router_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
if self.ports:
|
||||
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for router status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
shape = {"ACL": self.acl.space}
|
||||
if self.ports:
|
||||
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
|
||||
return spaces.Dict(shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation:
|
||||
"""
|
||||
Create a router observation from a configuration schema.
|
||||
|
||||
:param config: Configuration schema containing the necessary information for the router observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this router's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed router observation instance.
|
||||
:rtype: RouterObservation
|
||||
"""
|
||||
where = parent_where + ["nodes", config.hostname]
|
||||
|
||||
if config.acl is None:
|
||||
config.acl = ACLObservation.ConfigSchema()
|
||||
if config.acl.num_rules is None:
|
||||
config.acl.num_rules = config.num_rules
|
||||
if config.acl.ip_list is None:
|
||||
config.acl.ip_list = config.ip_list
|
||||
if config.acl.wildcard_list is None:
|
||||
config.acl.wildcard_list = config.wildcard_list
|
||||
if config.acl.port_list is None:
|
||||
config.acl.port_list = config.port_list
|
||||
if config.acl.protocol_list is None:
|
||||
config.acl.protocol_list = config.protocol_list
|
||||
|
||||
if config.ports is None:
|
||||
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
|
||||
|
||||
ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports]
|
||||
acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where)
|
||||
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)
|
||||
@@ -1,45 +1,46 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation):
|
||||
"""Observation of a service in the network."""
|
||||
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
|
||||
"""Service observation, shows status of a service in the simulation environment."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ServiceObservation."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise service observation.
|
||||
service_name: str
|
||||
"""Name of the service, used for querying simulation state dictionary"""
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'services', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialise a service observation instance.
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this service.
|
||||
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0}
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Observation containing the operating status and health status of the service.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
@@ -50,114 +51,120 @@ class ServiceObservation(AbstractObservation):
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for service status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> ServiceObservation:
|
||||
"""
|
||||
Create a service observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:param config: Configuration schema containing the necessary information for the service observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this service's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed service observation instance.
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["service_name"]])
|
||||
return cls(where=parent_where + ["services", config.service_name])
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation):
|
||||
"""Observation of an application in the network."""
|
||||
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
|
||||
"""Application observation, shows the status of an application within the simulation environment."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
"Default observation is what should be returned when the application doesn't exist."
|
||||
class ConfigSchema(AbstractObservation.ConfigSchema):
|
||||
"""Configuration schema for ApplicationObservation."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise application observation.
|
||||
application_name: str
|
||||
"""Name of the application, used for querying simulation state dictionary"""
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'applications', <application_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
def __init__(self, where: WhereType) -> None:
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
Initialise an application observation instance.
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this application.
|
||||
A typical location for an application might be
|
||||
['network', 'nodes', <node_hostname>, 'applications', <application_name>].
|
||||
:type where: WhereType
|
||||
"""
|
||||
self.where = where
|
||||
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
# TODO: allow these to be configured in yaml
|
||||
self.high_threshold = 10
|
||||
self.med_threshold = 5
|
||||
self.low_threshold = 0
|
||||
|
||||
def _categorise_num_executions(self, num_executions: int) -> int:
|
||||
"""
|
||||
Represent number of file accesses as a categorical variable.
|
||||
|
||||
:param num_access: Number of file accesses.
|
||||
:return: Bin number corresponding to the number of accesses.
|
||||
"""
|
||||
if num_executions > self.high_threshold:
|
||||
return 3
|
||||
elif num_executions > self.med_threshold:
|
||||
return 2
|
||||
elif num_executions > self.low_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary.
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
:return: Obs containing the operating status, health status, and number of executions of the application.
|
||||
:rtype: ObsType
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
app_state = access_from_nested_dict(state, self.where)
|
||||
if app_state is NOT_PRESENT_IN_STATE:
|
||||
application_state = access_from_nested_dict(state, self.where)
|
||||
if application_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": app_state["operating_state"],
|
||||
"health_status": app_state["health_state_visible"],
|
||||
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
|
||||
"operating_status": application_state["operating_state"],
|
||||
"health_status": application_state["health_state_visible"],
|
||||
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space representing the observation space for application status.
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"operating_status": spaces.Discrete(7),
|
||||
"health_status": spaces.Discrete(6),
|
||||
"health_status": spaces.Discrete(5),
|
||||
"num_executions": spaces.Discrete(4),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ApplicationObservation":
|
||||
"""Create application observation from a config.
|
||||
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
|
||||
) -> ApplicationObservation:
|
||||
"""
|
||||
Create an application observation from a configuration schema.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:param config: Configuration schema containing the necessary information for the application observation.
|
||||
:type config: ConfigSchema
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this application's
|
||||
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
|
||||
:type parent_where: WhereType, optional
|
||||
:return: Constructed application observation instance.
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["application_name"]])
|
||||
|
||||
@classmethod
|
||||
def _categorise_num_executions(cls, num_executions: int) -> int:
|
||||
"""
|
||||
Categorise the number of executions of an application.
|
||||
|
||||
Helps classify the number of application executions into different categories.
|
||||
|
||||
Current categories:
|
||||
- 0: Application is never executed
|
||||
- 1: Application is executed a low number of times (1-5)
|
||||
- 2: Application is executed often (6-10)
|
||||
- 3: Application is executed a high number of times (more than 10)
|
||||
|
||||
:param: num_executions: Number of times the application is executed
|
||||
"""
|
||||
if num_executions > 10:
|
||||
return 3
|
||||
elif num_executions > 5:
|
||||
return 2
|
||||
elif num_executions > 0:
|
||||
return 1
|
||||
return 0
|
||||
return cls(where=parent_where + ["applications", config.application_name])
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Hashable, Sequence
|
||||
from typing import Any, Dict, Hashable, Optional, Sequence
|
||||
|
||||
NOT_PRESENT_IN_STATE = object()
|
||||
"""
|
||||
@@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
|
||||
"""
|
||||
|
||||
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any:
|
||||
"""
|
||||
Access an item from a deeply dictionary with a list of keys.
|
||||
|
||||
@@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
|
||||
:return: The value in the dictionary
|
||||
:rtype: Any
|
||||
"""
|
||||
if keys is None:
|
||||
return NOT_PRESENT_IN_STATE
|
||||
key_list = [*keys] # copy keys to a new list to prevent editing original list
|
||||
if len(key_list) == 0:
|
||||
return dictionary
|
||||
|
||||
@@ -148,8 +148,10 @@ class ACLRule(SimComponent):
|
||||
state["action"] = self.action.value
|
||||
state["protocol"] = self.protocol.name if self.protocol else None
|
||||
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
|
||||
state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None
|
||||
state["src_port"] = self.src_port.name if self.src_port else None
|
||||
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
|
||||
state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None
|
||||
state["dst_port"] = self.dst_port.name if self.dst_port else None
|
||||
state["match_count"] = self.match_count
|
||||
return state
|
||||
|
||||
@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
|
||||
NMNE_CAPTURE_KEYWORDS: List[str] = []
|
||||
"""List of keywords to identify malicious network events."""
|
||||
|
||||
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
|
||||
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
|
||||
CAPTURE_BY_DIRECTION: Final[bool] = True
|
||||
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
|
||||
CAPTURE_BY_IP_ADDRESS: Final[bool] = False
|
||||
|
||||
@@ -22,8 +22,7 @@ agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -50,10 +49,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -86,63 +82,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -41,8 +41,7 @@ agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
@@ -41,8 +41,7 @@ agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
@@ -66,8 +66,7 @@ agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
@@ -26,8 +26,7 @@ agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -58,10 +57,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -102,63 +98,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -64,25 +64,67 @@ agents:
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: client_1
|
||||
links:
|
||||
- link_ref: client_1___switch_1
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.0.10
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- client_1___switch_1
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
# observation_space:
|
||||
# type: UC2BlueObservation
|
||||
# options:
|
||||
# num_services_per_node: 1
|
||||
# num_folders_per_node: 1
|
||||
# num_files_per_folder: 1
|
||||
# num_nics_per_node: 2
|
||||
# nodes:
|
||||
# - node_hostname: client_1
|
||||
# links:
|
||||
# - link_ref: client_1___switch_1
|
||||
# acl:
|
||||
# options:
|
||||
# max_acl_rules: 10
|
||||
# router_hostname: router_1
|
||||
# ip_address_order:
|
||||
# - node_hostname: client_1
|
||||
# nic_num: 1
|
||||
# ics: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
@@ -32,8 +32,7 @@ agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -61,10 +60,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -97,63 +93,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -553,63 +559,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -41,8 +41,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -91,8 +90,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -141,10 +139,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -177,61 +172,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -41,8 +41,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -91,8 +90,7 @@ agents:
|
||||
0: 0.3
|
||||
1: 0.6
|
||||
2: 0.1
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -141,10 +139,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -177,61 +172,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: DNSServer
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- node_hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -33,8 +33,7 @@ agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -65,10 +64,7 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
@@ -110,65 +106,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
# services:
|
||||
# - service_name: backup_service
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -26,8 +26,7 @@ agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -65,10 +64,8 @@ agents:
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes: {}
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
@@ -109,63 +106,73 @@ agents:
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
type: CUSTOM
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_hostname: domain_controller
|
||||
services:
|
||||
- service_name: domain_controller_dns_server
|
||||
- node_hostname: web_server
|
||||
services:
|
||||
- service_name: web_server_database_client
|
||||
- node_hostname: database_server
|
||||
services:
|
||||
- service_name: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_hostname: backup_server
|
||||
- node_hostname: security_suite
|
||||
- node_hostname: client_1
|
||||
- node_hostname: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_hostname: router_1
|
||||
ip_address_order:
|
||||
- node_hostname: domain_controller
|
||||
nic_num: 1
|
||||
- node_hostname: web_server
|
||||
nic_num: 1
|
||||
- node_hostname: database_server
|
||||
nic_num: 1
|
||||
- node_hostname: backup_server
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 1
|
||||
- node_hostname: client_1
|
||||
nic_num: 1
|
||||
- node_hostname: client_2
|
||||
nic_num: 1
|
||||
- node_hostname: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1___switch_1
|
||||
- router_1___switch_2
|
||||
- switch_1___domain_controller
|
||||
- switch_1___web_server
|
||||
- switch_1___database_server
|
||||
- switch_1___backup_server
|
||||
- switch_1___security_suite
|
||||
- switch_2___client_1
|
||||
- switch_2___client_2
|
||||
- switch_2___security_suite
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
|
||||
@@ -10,8 +10,7 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.observations.observations import ICSObservation
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.session import PrimaiteSession
|
||||
@@ -533,7 +532,7 @@ def game_and_agent():
|
||||
ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
|
||||
act_map={},
|
||||
)
|
||||
observation_space = ObservationManager(ICSObservation())
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
reward_function = RewardFunction()
|
||||
|
||||
test_agent = ControlledAgent(
|
||||
|
||||
@@ -13,7 +13,9 @@ MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
class TestPrimaiteSession:
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_creating_session(self, temp_primaite_session):
|
||||
"""Check that creating a session from config works."""
|
||||
@@ -56,6 +58,7 @@ class TestPrimaiteSession:
|
||||
assert checkpoint_2.exists()
|
||||
assert not checkpoint_3.exists()
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
|
||||
def test_training_only_session(self, temp_primaite_session):
|
||||
"""Check that you can run a training-only session."""
|
||||
@@ -64,6 +67,7 @@ class TestPrimaiteSession:
|
||||
session.start_session()
|
||||
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
|
||||
def test_eval_only_session(self, temp_primaite_session):
|
||||
"""Check that you can load a model and run an eval-only session."""
|
||||
@@ -72,6 +76,7 @@ class TestPrimaiteSession:
|
||||
session.start_session()
|
||||
# TODO: include checks that the model was loaded and that the eval-only session ran
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.skip(reason="Slow, reenable later")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
|
||||
def test_multi_agent_session(self, temp_primaite_session):
|
||||
@@ -79,10 +84,12 @@ class TestPrimaiteSession:
|
||||
with temp_primaite_session as session:
|
||||
session.start_session()
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.skip(
|
||||
reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, "
|
||||
"reset doesn't implement software restore."
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.observations.observations import AclObservation
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
@@ -34,11 +34,13 @@ def test_acl_observations(simulation):
|
||||
# add router acl rule
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
|
||||
|
||||
acl_obs = AclObservation(
|
||||
acl_obs = ACLObservation(
|
||||
where=["network", "nodes", router.hostname, "acl", "acl"],
|
||||
node_ip_to_id={},
|
||||
ports=["NTP", "HTTP", "POSTGRES_SERVER"],
|
||||
protocols=["TCP", "UDP", "ICMP"],
|
||||
ip_list=[],
|
||||
port_list=["NTP", "HTTP", "POSTGRES_SERVER"],
|
||||
protocol_list=["TCP", "UDP", "ICMP"],
|
||||
num_rules=10,
|
||||
wildcard_list=[],
|
||||
)
|
||||
|
||||
observation_space = acl_obs.observe(simulation.describe_state())
|
||||
@@ -46,11 +48,11 @@ def test_acl_observations(simulation):
|
||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||
assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1)
|
||||
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
|
||||
assert rule_obs.get("source_node_id") == 1 # applies to all source nodes
|
||||
assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes
|
||||
assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
|
||||
assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2
|
||||
assert rule_obs.get("protocol") == 1 # 1 = No Protocol
|
||||
assert rule_obs.get("source_ip_id") == 1 # applies to all source nodes
|
||||
assert rule_obs.get("dest_ip_id") == 1 # applies to all destination nodes
|
||||
assert rule_obs.get("source_port_id") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
|
||||
assert rule_obs.get("dest_port_id") == 2 # NTP port is mapped to value 2
|
||||
assert rule_obs.get("protocol_id") == 1 # 1 = No Protocol
|
||||
|
||||
router.acl.remove_rule(1)
|
||||
|
||||
@@ -59,8 +61,8 @@ def test_acl_observations(simulation):
|
||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||
assert rule_obs.get("position") == 0
|
||||
assert rule_obs.get("permission") == 0
|
||||
assert rule_obs.get("source_node_id") == 0
|
||||
assert rule_obs.get("dest_node_id") == 0
|
||||
assert rule_obs.get("source_port") == 0
|
||||
assert rule_obs.get("dest_port") == 0
|
||||
assert rule_obs.get("protocol") == 0
|
||||
assert rule_obs.get("source_ip_id") == 0
|
||||
assert rule_obs.get("dest_ip_id") == 0
|
||||
assert rule_obs.get("source_port_id") == 0
|
||||
assert rule_obs.get("dest_port_id") == 0
|
||||
assert rule_obs.get("protocol_id") == 0
|
||||
|
||||
@@ -23,7 +23,8 @@ def test_file_observation(simulation):
|
||||
file = pc.file_system.create_file(file_name="dog.png")
|
||||
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
)
|
||||
|
||||
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
|
||||
@@ -49,7 +50,10 @@ def test_folder_observation(simulation):
|
||||
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
|
||||
|
||||
root_folder_obs = FolderObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"]
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
|
||||
include_num_access=False,
|
||||
num_files=1,
|
||||
files=[],
|
||||
)
|
||||
|
||||
assert root_folder_obs.space["health_status"] == spaces.Discrete(6)
|
||||
@@ -68,3 +72,6 @@ def test_folder_observation(simulation):
|
||||
|
||||
observation_state = root_folder_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too
|
||||
|
||||
|
||||
# TODO: Add tests to check num access is correct.
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
from primaite.game.agent.observations.firewall_observation import FirewallObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
|
||||
|
||||
def check_default_rules(acl_obs):
|
||||
assert len(acl_obs) == 7
|
||||
assert all(acl_obs[i]["position"] == i - 1 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["source_port_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(1, 8))
|
||||
assert all(acl_obs[i]["protocol_id"] == 0 for i in range(1, 8))
|
||||
|
||||
|
||||
def test_firewall_observation():
|
||||
"""Test adding/removing acl rules and enabling/disabling ports."""
|
||||
net = Network()
|
||||
firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON)
|
||||
firewall_observation = FirewallObservation(
|
||||
where=[],
|
||||
num_rules=7,
|
||||
ip_list=["10.0.0.1", "10.0.0.2"],
|
||||
wildcard_list=["0.0.0.255", "0.0.0.1"],
|
||||
port_list=["HTTP", "DNS"],
|
||||
protocol_list=["TCP"],
|
||||
)
|
||||
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
assert "ACL" in observation
|
||||
assert "PORTS" in observation
|
||||
assert "INTERNAL" in observation["ACL"]
|
||||
assert "EXTERNAL" in observation["ACL"]
|
||||
assert "DMZ" in observation["ACL"]
|
||||
assert "INBOUND" in observation["ACL"]["INTERNAL"]
|
||||
assert "OUTBOUND" in observation["ACL"]["INTERNAL"]
|
||||
assert "INBOUND" in observation["ACL"]["EXTERNAL"]
|
||||
assert "OUTBOUND" in observation["ACL"]["EXTERNAL"]
|
||||
assert "INBOUND" in observation["ACL"]["DMZ"]
|
||||
assert "OUTBOUND" in observation["ACL"]["DMZ"]
|
||||
all_acls = (
|
||||
observation["ACL"]["INTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["INTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["DMZ"]["INBOUND"],
|
||||
observation["ACL"]["DMZ"]["OUTBOUND"],
|
||||
)
|
||||
for acl_obs in all_acls:
|
||||
check_default_rules(acl_obs)
|
||||
|
||||
# add a rule to the internal inbound and check that the observation is correct
|
||||
firewall.internal_inbound_acl.add_rule(
|
||||
action=ACLAction.DENY,
|
||||
protocol=IPProtocol.TCP,
|
||||
src_ip_address="10.0.0.1",
|
||||
src_wildcard_mask="0.0.0.1",
|
||||
dst_ip_address="10.0.0.2",
|
||||
dst_wildcard_mask="0.0.0.1",
|
||||
src_port=Port.HTTP,
|
||||
dst_port=Port.HTTP,
|
||||
position=5,
|
||||
)
|
||||
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5]
|
||||
assert observed_rule["position"] == 4
|
||||
assert observed_rule["permission"] == 2
|
||||
assert observed_rule["source_ip_id"] == 2
|
||||
assert observed_rule["source_wildcard_id"] == 3
|
||||
assert observed_rule["source_port_id"] == 2
|
||||
assert observed_rule["dest_ip_id"] == 3
|
||||
assert observed_rule["dest_wildcard_id"] == 3
|
||||
assert observed_rule["dest_port_id"] == 2
|
||||
assert observed_rule["protocol_id"] == 2
|
||||
|
||||
# check that none of the other acls have changed
|
||||
all_acls = (
|
||||
observation["ACL"]["INTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["DMZ"]["INBOUND"],
|
||||
observation["ACL"]["DMZ"]["OUTBOUND"],
|
||||
)
|
||||
for acl_obs in all_acls:
|
||||
check_default_rules(acl_obs)
|
||||
|
||||
# remove the rule and check that the observation is correct
|
||||
firewall.internal_inbound_acl.remove_rule(5)
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
all_acls = (
|
||||
observation["ACL"]["INTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["INTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["INBOUND"],
|
||||
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
|
||||
observation["ACL"]["DMZ"]["INBOUND"],
|
||||
observation["ACL"]["DMZ"]["OUTBOUND"],
|
||||
)
|
||||
for acl_obs in all_acls:
|
||||
check_default_rules(acl_obs)
|
||||
|
||||
# check that there are three ports in the observation
|
||||
assert len(observation["PORTS"]) == 3
|
||||
|
||||
# check that the ports are all disabled
|
||||
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))
|
||||
|
||||
# connect a switch to the firewall and check that only the correct port is updated
|
||||
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
|
||||
link = net.connect(firewall.network_interface[1], switch.network_interface[1])
|
||||
assert firewall.network_interface[1].enabled
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
assert observation["PORTS"][1]["operating_status"] == 1
|
||||
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(2, 4))
|
||||
|
||||
# disable the port and check that the operating status is updated
|
||||
firewall.network_interface[1].disable()
|
||||
assert not firewall.network_interface[1].enabled
|
||||
observation = firewall_observation.observe(firewall.describe_state())
|
||||
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))
|
||||
@@ -1,11 +1,13 @@
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.observations import LinkObservation
|
||||
from primaite.game.agent.observations.link_observation import LinkObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, Node
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
@@ -49,25 +51,44 @@ def simulation() -> Simulation:
|
||||
return sim
|
||||
|
||||
|
||||
def test_link_observation(simulation):
|
||||
"""Test the link observation."""
|
||||
# get a link
|
||||
link: Link = next(iter(simulation.network.links.values()))
|
||||
def test_link_observation():
|
||||
"""Check the shape and contents of the link observation."""
|
||||
net = Network()
|
||||
sim = Simulation(network=net)
|
||||
switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON)
|
||||
computer_1 = Computer(
|
||||
hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
computer_2 = Computer(
|
||||
hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
computer_1.power_on()
|
||||
computer_2.power_on()
|
||||
link_1 = net.connect(switch.network_interface[1], computer_1.network_interface[1])
|
||||
link_2 = net.connect(switch.network_interface[2], computer_2.network_interface[1])
|
||||
assert link_1 is not None
|
||||
assert link_2 is not None
|
||||
|
||||
computer: Computer = simulation.network.get_node_by_hostname("computer")
|
||||
server: Server = simulation.network.get_node_by_hostname("server")
|
||||
link_1_observation = LinkObservation(where=["network", "links", link_1.uuid])
|
||||
link_2_observation = LinkObservation(where=["network", "links", link_2.uuid])
|
||||
|
||||
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
|
||||
state = sim.describe_state()
|
||||
link_1_obs = link_1_observation.observe(state)
|
||||
link_2_obs = link_2_observation.observe(state)
|
||||
assert "PROTOCOLS" in link_1_obs
|
||||
assert "PROTOCOLS" in link_2_obs
|
||||
assert "ALL" in link_1_obs["PROTOCOLS"]
|
||||
assert "ALL" in link_2_obs["PROTOCOLS"]
|
||||
assert link_1_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
|
||||
assert link_2_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
|
||||
assert link_1_obs["PROTOCOLS"]["ALL"] == 0
|
||||
assert link_2_obs["PROTOCOLS"]["ALL"] == 0
|
||||
|
||||
link_obs = LinkObservation(where=["network", "links", link.uuid])
|
||||
|
||||
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("PROTOCOLS") is not None
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 0
|
||||
|
||||
computer.ping(server.network_interface.get(1).ip_address)
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 1
|
||||
# Test that the link observation is updated when a packet is sent
|
||||
computer_1.ping("10.0.0.2")
|
||||
computer_2.ping("10.0.0.1")
|
||||
state = sim.describe_state()
|
||||
link_1_obs = link_1_observation.observe(state)
|
||||
link_2_obs = link_2_observation.observe(state)
|
||||
assert link_1_obs["PROTOCOLS"]["ALL"] > 0
|
||||
assert link_2_obs["PROTOCOLS"]["ALL"] > 0
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
import yaml
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
@@ -40,7 +40,7 @@ def test_nic(simulation):
|
||||
|
||||
nic: NIC = pc.network_interface[1]
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
|
||||
|
||||
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
|
||||
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
|
||||
@@ -61,17 +61,22 @@ def test_nic_categories(simulation):
|
||||
"""Test the NIC observation nmne count categories."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 10 # default
|
||||
assert nic_obs.med_nmne_threshold == 5 # default
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
|
||||
nic_obs = NicObservation(
|
||||
|
||||
@pytest.mark.skip(reason="Feature not implemented yet")
|
||||
def test_config_nic_categories(simulation):
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
nic_obs = NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 9
|
||||
@@ -80,18 +85,20 @@ def test_nic_categories(simulation):
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=9,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
NICObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=9,
|
||||
high_nmne_threshold=9,
|
||||
include_nmne=True,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.host_observations import HostObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
@@ -19,15 +19,28 @@ def simulation(example_network) -> Simulation:
|
||||
return sim
|
||||
|
||||
|
||||
def test_node_observation(simulation):
|
||||
"""Test a Node observation."""
|
||||
def test_host_observation(simulation):
|
||||
"""Test a Host observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
|
||||
host_obs = HostObservation(
|
||||
where=["network", "nodes", pc.hostname],
|
||||
num_applications=0,
|
||||
num_files=1,
|
||||
num_folders=1,
|
||||
num_nics=2,
|
||||
num_services=1,
|
||||
include_num_access=False,
|
||||
include_nmne=False,
|
||||
services=[],
|
||||
applications=[],
|
||||
folders=[],
|
||||
network_interfaces=[],
|
||||
)
|
||||
|
||||
assert node_obs.space["operating_status"] == spaces.Discrete(5)
|
||||
assert host_obs.space["operating_status"] == spaces.Discrete(5)
|
||||
|
||||
observation_state = node_obs.observe(simulation.describe_state())
|
||||
observation_state = host_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("operating_status") == 1 # computer is on
|
||||
|
||||
assert observation_state.get("SERVICES") is not None
|
||||
@@ -36,11 +49,11 @@ def test_node_observation(simulation):
|
||||
|
||||
# turn off computer
|
||||
pc.power_off()
|
||||
observation_state = node_obs.observe(simulation.describe_state())
|
||||
observation_state = host_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("operating_status") == 4 # shutting down
|
||||
|
||||
for i in range(pc.shut_down_duration + 1):
|
||||
pc.apply_timestep(i)
|
||||
|
||||
observation_state = node_obs.observe(simulation.describe_state())
|
||||
observation_state = host_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("operating_status") == 2
|
||||
|
||||
@@ -0,0 +1,108 @@
|
||||
from pprint import pprint
|
||||
|
||||
from primaite.game.agent.observations.acl_observation import ACLObservation
|
||||
from primaite.game.agent.observations.nic_observations import PortObservation
|
||||
from primaite.game.agent.observations.router_observation import RouterObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
def test_router_observation():
|
||||
"""Test adding/removing acl rules and enabling/disabling ports."""
|
||||
net = Network()
|
||||
router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON)
|
||||
|
||||
ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)]
|
||||
acl = ACLObservation(
|
||||
where=["acl", "acl"],
|
||||
num_rules=7,
|
||||
ip_list=["10.0.0.1", "10.0.0.2"],
|
||||
wildcard_list=["0.0.0.255", "0.0.0.1"],
|
||||
port_list=["HTTP", "DNS"],
|
||||
protocol_list=["TCP"],
|
||||
)
|
||||
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl)
|
||||
|
||||
# Observe the state using the RouterObservation instance
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
|
||||
# Check that the right number of ports and acls are in the router observation
|
||||
assert len(observed_output["PORTS"]) == 8
|
||||
assert len(observed_output["ACL"]) == 7
|
||||
|
||||
# Add an ACL rule to the router
|
||||
router.acl.add_rule(
|
||||
action=ACLAction.DENY,
|
||||
protocol=IPProtocol.TCP,
|
||||
src_ip_address="10.0.0.1",
|
||||
src_wildcard_mask="0.0.0.1",
|
||||
dst_ip_address="10.0.0.2",
|
||||
dst_wildcard_mask="0.0.0.1",
|
||||
src_port=Port.HTTP,
|
||||
dst_port=Port.HTTP,
|
||||
position=5,
|
||||
)
|
||||
# Observe the state using the RouterObservation instance
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
observed_rule = observed_output["ACL"][5]
|
||||
assert observed_rule["position"] == 4
|
||||
assert observed_rule["permission"] == 2
|
||||
assert observed_rule["source_ip_id"] == 2
|
||||
assert observed_rule["source_wildcard_id"] == 3
|
||||
assert observed_rule["source_port_id"] == 2
|
||||
assert observed_rule["dest_ip_id"] == 3
|
||||
assert observed_rule["dest_wildcard_id"] == 3
|
||||
assert observed_rule["dest_port_id"] == 2
|
||||
assert observed_rule["protocol_id"] == 2
|
||||
|
||||
# Add an ACL rule with ALL/NONE values and check that the observation is correct
|
||||
router.acl.add_rule(
|
||||
action=ACLAction.PERMIT,
|
||||
protocol=None,
|
||||
src_ip_address=None,
|
||||
src_wildcard_mask=None,
|
||||
dst_ip_address=None,
|
||||
dst_wildcard_mask=None,
|
||||
src_port=None,
|
||||
dst_port=None,
|
||||
position=2,
|
||||
)
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
observed_rule = observed_output["ACL"][2]
|
||||
assert observed_rule["position"] == 1
|
||||
assert observed_rule["permission"] == 1
|
||||
assert observed_rule["source_ip_id"] == 1
|
||||
assert observed_rule["source_wildcard_id"] == 1
|
||||
assert observed_rule["source_port_id"] == 1
|
||||
assert observed_rule["dest_ip_id"] == 1
|
||||
assert observed_rule["dest_wildcard_id"] == 1
|
||||
assert observed_rule["dest_port_id"] == 1
|
||||
assert observed_rule["protocol_id"] == 1
|
||||
|
||||
# Check that the router ports are all disabled
|
||||
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
|
||||
|
||||
# connect a switch to the router and check that only the correct port is updated
|
||||
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
|
||||
link = net.connect(router.network_interface[1], switch.network_interface[1])
|
||||
assert router.network_interface[1].enabled
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
assert observed_output["PORTS"][1]["operating_status"] == 1
|
||||
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(2, 6))
|
||||
|
||||
# disable the port and check that the operating status is updated
|
||||
router.network_interface[1].disable()
|
||||
assert not router.network_interface[1].enabled
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
|
||||
|
||||
# Check that ports that are out of range are shown as unused
|
||||
observed_output = router_observation.observe(router.describe_state())
|
||||
assert observed_output["PORTS"][6]["operating_status"] == 0
|
||||
assert observed_output["PORTS"][7]["operating_status"] == 0
|
||||
assert observed_output["PORTS"][8]["operating_status"] == 0
|
||||
@@ -14,7 +14,13 @@ def test_file_observation():
|
||||
state = sim.describe_state()
|
||||
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
|
||||
include_num_access=False,
|
||||
)
|
||||
assert dog_file_obs.observe(state) == {"health_status": 1}
|
||||
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
|
||||
# TODO:
|
||||
# def test_file_num_access():
|
||||
# ...
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
@@ -141,9 +141,9 @@ def test_describe_state_nmne(uc2_network):
|
||||
|
||||
def test_capture_nmne_observations(uc2_network):
|
||||
"""
|
||||
Tests the NicObservation class's functionality within a simulated network environment.
|
||||
Tests the NICObservation class's functionality within a simulated network environment.
|
||||
|
||||
This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the
|
||||
This test ensures the observation space, as defined by instances of NICObservation, accurately reflects the
|
||||
number of MNEs detected based on network activities over multiple iterations.
|
||||
|
||||
The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update
|
||||
@@ -168,8 +168,8 @@ def test_capture_nmne_observations(uc2_network):
|
||||
set_nmne_config(nmne_config)
|
||||
|
||||
# Define observations for the NICs of the database and web servers
|
||||
db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1])
|
||||
web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1])
|
||||
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)
|
||||
web_server_nic_obs = NICObservation(where=["network", "nodes", "web_server", "NICs", 1], include_nmne=True)
|
||||
|
||||
# Iterate through a set of test cases to simulate multiple DELETE queries
|
||||
for i in range(0, 20):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.observations.observations import ICSObservation
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
|
||||
@@ -52,7 +51,7 @@ def test_probabilistic_agent():
|
||||
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
|
||||
},
|
||||
)
|
||||
observation_space = ObservationManager(ICSObservation())
|
||||
observation_space = ObservationManager(NestedObservation(components={}))
|
||||
reward_function = RewardFunction()
|
||||
|
||||
pa = ProbabilisticAgent(
|
||||
|
||||
Reference in New Issue
Block a user