Merged PR 321: CAOS 0.8 observations

## Summary
* Remove the usecase-specific and agent-specific observation classes, replacing with a more flexible system
* Add configuration schemas to every observation class
* Add router, firewall, port, and application observation
* Re-shape the dict structure of observations to make it adhere to CAOS 0.8
* Change existing configs to use the new structure
* make host observation separate

## Test process
existing and new unit tests as well as ad hoc notebooks

## Checklist
- [ ] PR is linked to a **work item**
- [ ] **acceptance criteria** of linked ticket are met
- [ ] performed **self-review** of the code
- [ ] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [ ] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2417
This commit is contained in:
Marek Wolan
2024-04-02 14:00:27 +00:00
43 changed files with 2922 additions and 1745 deletions

View File

@@ -120,7 +120,7 @@ SessionManager.
- Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios.
- **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules.
- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them.
- Integration of NMNE capturing functionality within the `NicObservation` class.
- Integration of NMNE capturing functionality within the `NICObservation` class.
- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario
### Removed

View File

@@ -73,7 +73,7 @@ Network Interface Classes
- Malicious Network Events Monitoring:
* Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns.
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies.
* Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NICObservation`` to classify and record network anomalies.
* Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms.
**WiredNetworkInterface (Connection Type Layer)**

View File

@@ -41,8 +41,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -91,8 +90,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -141,10 +139,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -177,61 +172,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -40,8 +40,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -90,8 +89,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -140,10 +138,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -179,61 +174,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
@@ -742,61 +749,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -0,0 +1,20 @@
# flake8: noqa
# Pre-import all the observations when we load up the observations module so that they can be resolved by the parser.
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
from primaite.game.agent.observations.firewall_observation import FirewallObservation
from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.link_observation import LinkObservation, LinksObservation
from primaite.game.agent.observations.nic_observations import NICObservation, PortObservation
from primaite.game.agent.observations.node_observations import NodesObservation
from primaite.game.agent.observations.observation_manager import NestedObservation, NullObservation, ObservationManager
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.router_observation import RouterObservation
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
# fmt: off
__all__ = [
"ACLObservation", "FileObservation", "FolderObservation", "FirewallObservation", "HostObservation",
"LinksObservation", "NICObservation", "PortObservation", "NodesObservation", "NestedObservation",
"ObservationManager", "ApplicationObservation", "ServiceObservation",]
# fmt: on

View File

@@ -0,0 +1,189 @@
from __future__ import annotations
from ipaddress import IPv4Address
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class ACLObservation(AbstractObservation, identifier="ACL"):
"""ACL observation, provides information about access control lists within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ACLObservation."""
ip_list: Optional[List[IPv4Address]] = None
"""List of IP addresses."""
wildcard_list: Optional[List[str]] = None
"""List of wildcard strings."""
port_list: Optional[List[int]] = None
"""List of port numbers."""
protocol_list: Optional[List[str]] = None
"""List of protocol names."""
num_rules: Optional[int] = None
"""Number of ACL rules."""
def __init__(
self,
where: WhereType,
num_rules: int,
ip_list: List[IPv4Address],
wildcard_list: List[str],
port_list: List[int],
protocol_list: List[str],
) -> None:
"""
Initialise an ACL observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this ACL.
:type where: WhereType
:param num_rules: Number of ACL rules.
:type num_rules: int
:param ip_list: List of IP addresses.
:type ip_list: List[IPv4Address]
:param wildcard_list: List of wildcard strings.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param protocol_list: List of protocol names.
:type protocol_list: List[str]
"""
self.where = where
self.num_rules: int = num_rules
self.ip_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(ip_list)}
self.wildcard_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(wildcard_list)}
self.port_to_id: Dict[int, int] = {p: i + 2 for i, p in enumerate(port_list)}
self.protocol_to_id: Dict[str, int] = {p: i + 2 for i, p in enumerate(protocol_list)}
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol_id": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing ACL rules.
:rtype: ObsType
"""
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_ip_id": 0,
"source_wildcard_id": 0,
"source_port_id": 0,
"dest_ip_id": 0,
"dest_wildcard_id": 0,
"dest_port_id": 0,
"protocol_id": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.ip_to_id[src_ip]
dst_ip = rule_state["dst_ip_address"]
dst_node_id = 1 if dst_ip is None else self.ip_to_id[dst_ip]
src_wildcard = rule_state["src_wildcard_mask"]
src_wildcard_id = self.wildcard_to_id.get(src_wildcard, 1)
dst_wildcard = rule_state["dst_wildcard_mask"]
dst_wildcard_id = self.wildcard_to_id.get(dst_wildcard, 1)
src_port = rule_state["src_port"]
src_port_id = self.port_to_id.get(src_port, 1)
dst_port = rule_state["dst_port"]
dst_port_id = self.port_to_id.get(dst_port, 1)
protocol = rule_state["protocol"]
protocol_id = self.protocol_to_id.get(protocol, 1)
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_ip_id": src_node_id,
"source_wildcard_id": src_wildcard_id,
"source_port_id": src_port_id,
"dest_ip_id": dst_node_id,
"dest_wildcard_id": dst_wildcard_id,
"dest_port_id": dst_port_id,
"protocol_id": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for ACL rules.
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"source_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
"source_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"dest_ip_id": spaces.Discrete(len(self.ip_to_id) + 2),
"dest_wildcard_id": spaces.Discrete(len(self.wildcard_to_id) + 2),
"dest_port_id": spaces.Discrete(len(self.port_to_id) + 2),
"protocol_id": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> ACLObservation:
"""
Create an ACL observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the ACL observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this ACL's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed ACL observation instance.
:rtype: ACLObservation
"""
return cls(
where=parent_where + ["acl", "acl"],
num_rules=config.num_rules,
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,
protocol_list=config.protocol_list,
)

View File

@@ -1,188 +0,0 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.observations import (
AbstractObservation,
AclObservation,
ICSObservation,
LinkObservation,
NullObservation,
)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class UC2BlueObservation(AbstractObservation):
"""Container for all observations used by the blue agent in UC2.
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
for the purpose of compiling several observation components.
"""
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where: Optional[List[str]] = None,
) -> None:
"""Initialise UC2 blue observation.
:param nodes: List of node observations
:type nodes: List[NodeObservation]
:param links: List of link observations
:type links: List[LinkObservation]
:param acl: The Access Control List observation
:type acl: AclObservation
:param ics: The ICS observation
:type ics: ICSObservation
:param where: Where in the simulation state dict to find information. Not used in this particular observation
because it only compiles other observations and doesn't contribute any new information, defaults to None
:type where: Optional[List[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs["ACL"] = self.acl.observe(state)
obs["ICS"] = self.ics.observe(state)
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]
num_nics_per_node = config["num_nics_per_node"]
nodes = [
NodeObservation.from_config(
config=n,
game=game,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
for n in node_configs
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, game=game)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, game=game)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
class UC2RedObservation(AbstractObservation):
"""Container for all observations used by the red agent in UC2."""
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation):
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
pass

View File

@@ -1,126 +1,170 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, Iterable, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class FileObservation(AbstractObservation):
"""Observation of a file on a node in the network."""
class FileObservation(AbstractObservation, identifier="FILE"):
"""File observation, provides status information about a file within the simulation environment."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FileObservation."""
file_name: str
"""Name of the file, used for querying simulation state dictionary."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to the file in the observation."""
def __init__(self, where: WhereType, include_num_access: bool) -> None:
"""
Initialise file observation.
Initialise a file observation instance.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
:param where: Where in the simulation state dictionary to find the relevant information for this file.
A typical location for a file might be
['network', 'nodes', <node_hostname>, 'file_system', 'folder', <folder_name>, 'files', <file_name>].
:type where: WhereType
:param include_num_access: Whether to include the number of accesses to the file in the observation.
:type include_num_access: bool
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
self.where: WhereType = where
self.include_num_access: bool = include_num_access
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
self.default_observation: ObsType = {"health_status": 0}
if self.include_num_access:
self.default_observation["num_access"] = 0
:param state: Simulation state dictionary
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_access(self, num_access: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_access > self.high_threshold:
return 3
elif num_access > self.med_threshold:
return 2
elif num_access > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing the health status of the file and optionally the number of accesses.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"health_status": file_state["visible_status"]}
obs = {"health_status": file_state["visible_status"]}
if self.include_num_access:
obs["num_access"] = self._categorise_num_access(file_state["num_access"])
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:return: Gymnasium space representing the observation space for file status.
:rtype: spaces.Space
"""
return spaces.Dict({"health_status": spaces.Discrete(6)})
space = {"health_status": spaces.Discrete(6)}
if self.include_num_access:
space["num_access"] = spaces.Discrete(4)
return spaces.Dict(space)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param game: _description_
:type game: PrimaiteGame
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
:rtype: _type_
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FileObservation:
"""
return cls(where=parent_where + ["files", config["file_name"]])
Create a file observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the file observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this file's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed file observation instance.
:rtype: FileObservation
"""
return cls(where=parent_where + ["files", config.file_name], include_num_access=config.include_num_access)
class FolderObservation(AbstractObservation):
"""Folder observation, including files inside of the folder."""
class FolderObservation(AbstractObservation, identifier="FOLDER"):
"""Folder observation, provides status information about a folder within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FolderObservation."""
folder_name: str
"""Name of the folder, used for querying simulation state dictionary."""
files: List[FileObservation.ConfigSchema] = []
"""List of file configurations within the folder."""
num_files: Optional[int] = None
"""Number of spaces for file observations in this folder."""
include_num_access: Optional[bool] = None
"""Whether files in this folder should include the number of accesses in their observation."""
def __init__(
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
self, where: WhereType, files: Iterable[FileObservation], num_files: int, include_num_access: bool
) -> None:
"""Initialise folder Observation, including files inside the folder.
"""
Initialise a folder observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
:type max_files: int, optional
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
that even if new files are created, the existing files will always occupy the same space in the observation
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
name, it will take the position defined in this dict. Defaults to {}
:type file_positions: Dict[int, str], optional
A typical location for a folder might be ['network', 'nodes', <node_hostname>, 'folders', <folder_name>].
:type where: WhereType
:param files: List of file observation instances within the folder.
:type files: Iterable[FileObservation]
:param num_files: Number of files expected in the folder.
:type num_files: int
:param include_num_access: Whether to include the number of accesses to files in the observation.
:type include_num_access: bool
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.where: WhereType = where
self.files: List[FileObservation] = files
while len(self.files) < num_files_per_folder:
self.files.append(FileObservation())
while len(self.files) > num_files_per_folder:
while len(self.files) < num_files:
self.files.append(FileObservation(where=None, include_num_access=include_num_access))
while len(self.files) > num_files:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
_LOGGER.warning(msg)
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
if self.files:
self.default_observation["FILES"] = {i + 1: f.default_observation for i, f in enumerate(self.files)}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the health status of the folder and status of files within the folder.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
@@ -130,48 +174,42 @@ class FolderObservation(AbstractObservation):
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
if self.files:
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:return: Gymnasium space representing the observation space for folder status.
:rtype: spaces.Space
"""
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
shape = {"health_status": spaces.Discrete(6)}
if self.files:
shape["FILES"] = spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)})
return spaces.Dict(shape)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> FolderObservation:
"""
Create a folder observation from a configuration schema.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param config: Configuration schema containing the necessary information for the folder observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_hostname>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed folder observation
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed folder observation instance.
:rtype: FolderObservation
"""
where = parent_where + ["folders", config["folder_name"]]
where = parent_where + ["folders", config.folder_name]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
# pass down shared/common config items
for file_config in config.files:
file_config.include_num_access = config.include_num_access
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in config.files]
return cls(where=where, files=files, num_files=config.num_files, include_num_access=config.include_num_access)

View File

@@ -0,0 +1,224 @@
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class FirewallObservation(AbstractObservation, identifier="FIREWALL"):
"""Firewall observation, provides status information about a firewall within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for FirewallObservation."""
hostname: str
"""Hostname of the firewall node, used for querying simulation state dictionary."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
def __init__(
self,
where: WhereType,
ip_list: List[str],
wildcard_list: List[str],
port_list: List[int],
protocol_list: List[str],
num_rules: int,
) -> None:
"""
Initialise a firewall observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this firewall.
A typical location for a firewall might be ['network', 'nodes', <firewall_hostname>].
:type where: WhereType
:param ip_list: List of IP addresses.
:type ip_list: List[str]
:param wildcard_list: List of wildcard rules.
:type wildcard_list: List[str]
:param port_list: List of port numbers.
:type port_list: List[int]
:param protocol_list: List of protocol types.
:type protocol_list: List[str]
:param num_rules: Number of rules configured in the firewall.
:type num_rules: int
"""
self.where: WhereType = where
self.ports: List[PortObservation] = [
PortObservation(where=self.where + ["NICs", port_num]) for port_num in (1, 2, 3)
]
# TODO: check what the port nums are for firewall.
self.internal_inbound_acl = ACLObservation(
where=self.where + ["internal_inbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.internal_outbound_acl = ACLObservation(
where=self.where + ["internal_outbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.dmz_inbound_acl = ACLObservation(
where=self.where + ["dmz_inbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.dmz_outbound_acl = ACLObservation(
where=self.where + ["dmz_outbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.external_inbound_acl = ACLObservation(
where=self.where + ["external_inbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.external_outbound_acl = ACLObservation(
where=self.where + ["external_outbound_acl", "acl"],
num_rules=num_rules,
ip_list=ip_list,
wildcard_list=wildcard_list,
port_list=port_list,
protocol_list=protocol_list,
)
self.default_observation = {
"PORTS": {i + 1: p.default_observation for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.default_observation,
"OUTBOUND": self.internal_outbound_acl.default_observation,
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.default_observation,
"OUTBOUND": self.dmz_outbound_acl.default_observation,
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.default_observation,
"OUTBOUND": self.external_outbound_acl.default_observation,
},
},
}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of ports and ACLs for internal, DMZ, and external traffic.
:rtype: ObsType
"""
obs = {
"PORTS": {i + 1: p.observe(state) for i, p in enumerate(self.ports)},
"ACL": {
"INTERNAL": {
"INBOUND": self.internal_inbound_acl.observe(state),
"OUTBOUND": self.internal_outbound_acl.observe(state),
},
"DMZ": {
"INBOUND": self.dmz_inbound_acl.observe(state),
"OUTBOUND": self.dmz_outbound_acl.observe(state),
},
"EXTERNAL": {
"INBOUND": self.external_inbound_acl.observe(state),
"OUTBOUND": self.external_outbound_acl.observe(state),
},
},
}
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for firewall status.
:rtype: spaces.Space
"""
space = spaces.Dict(
{
"PORTS": spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)}),
"ACL": spaces.Dict(
{
"INTERNAL": spaces.Dict(
{
"INBOUND": self.internal_inbound_acl.space,
"OUTBOUND": self.internal_outbound_acl.space,
}
),
"DMZ": spaces.Dict(
{
"INBOUND": self.dmz_inbound_acl.space,
"OUTBOUND": self.dmz_outbound_acl.space,
}
),
"EXTERNAL": spaces.Dict(
{
"INBOUND": self.external_inbound_acl.space,
"OUTBOUND": self.external_outbound_acl.space,
}
),
}
),
}
)
return space
@classmethod
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> FirewallObservation:
"""
Create a firewall observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the firewall observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this firewall's
parent node. A typical location for a node might be ['network', 'nodes', <firewall_hostname>].
:type parent_where: WhereType, optional
:return: Constructed firewall observation instance.
:rtype: FirewallObservation
"""
return cls(
where=parent_where + ["nodes", config.hostname],
ip_list=config.ip_list,
wildcard_list=config.wildcard_list,
port_list=config.port_list,
protocol_list=config.protocol_list,
num_rules=config.num_rules,
)

View File

@@ -0,0 +1,248 @@
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class HostObservation(AbstractObservation, identifier="HOST"):
"""Host observation, provides status information about a host within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for HostObservation."""
hostname: str
"""Hostname of the host, used for querying simulation state dictionary."""
services: List[ServiceObservation.ConfigSchema] = []
"""List of services to observe on the host."""
applications: List[ApplicationObservation.ConfigSchema] = []
"""List of applications to observe on the host."""
folders: List[FolderObservation.ConfigSchema] = []
"""List of folders to observe on the host."""
network_interfaces: List[NICObservation.ConfigSchema] = []
"""List of network interfaces to observe on the host."""
num_services: Optional[int] = None
"""Number of spaces for service observations on this host."""
num_applications: Optional[int] = None
"""Number of spaces for application observations on this host."""
num_folders: Optional[int] = None
"""Number of spaces for folder observations on this host."""
num_files: Optional[int] = None
"""Number of spaces for file observations on this host."""
num_nics: Optional[int] = None
"""Number of spaces for network interface observations on this host."""
include_nmne: Optional[bool] = None
"""Whether network interface observations should include number of malicious network events."""
include_num_access: Optional[bool] = None
"""Whether to include the number of accesses to files observations on this host."""
def __init__(
self,
where: WhereType,
services: List[ServiceObservation],
applications: List[ApplicationObservation],
folders: List[FolderObservation],
network_interfaces: List[NICObservation],
num_services: int,
num_applications: int,
num_folders: int,
num_files: int,
num_nics: int,
include_nmne: bool,
include_num_access: bool,
) -> None:
"""
Initialise a host observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this host.
A typical location for a host might be ['network', 'nodes', <hostname>].
:type where: WhereType
:param services: List of service observations on the host.
:type services: List[ServiceObservation]
:param applications: List of application observations on the host.
:type applications: List[ApplicationObservation]
:param folders: List of folder observations on the host.
:type folders: List[FolderObservation]
:param network_interfaces: List of network interface observations on the host.
:type network_interfaces: List[NICObservation]
:param num_services: Number of services to observe.
:type num_services: int
:param num_applications: Number of applications to observe.
:type num_applications: int
:param num_folders: Number of folders to observe.
:type num_folders: int
:param num_files: Number of files.
:type num_files: int
:param num_nics: Number of network interfaces.
:type num_nics: int
:param include_nmne: Flag to include network metrics and errors.
:type include_nmne: bool
:param include_num_access: Flag to include the number of accesses to files.
:type include_num_access: bool
"""
self.where: WhereType = where
self.include_num_access = include_num_access
# Ensure lists have lengths equal to specified counts by truncating or padding
self.services: List[ServiceObservation] = services
while len(self.services) < num_services:
self.services.append(ServiceObservation(where=None))
while len(self.services) > num_services:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
self.applications: List[ApplicationObservation] = applications
while len(self.applications) < num_applications:
self.applications.append(ApplicationObservation(where=None))
while len(self.applications) > num_applications:
truncated_application = self.applications.pop()
msg = f"Too many applications in Node observation space for node. Truncating {truncated_application.where}"
_LOGGER.warning(msg)
self.folders: List[FolderObservation] = folders
while len(self.folders) < num_folders:
self.folders.append(
FolderObservation(where=None, files=[], num_files=num_files, include_num_access=include_num_access)
)
while len(self.folders) > num_folders:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation space for node. Truncating folder {truncated_folder.where}"
_LOGGER.warning(msg)
self.nics: List[NICObservation] = network_interfaces
while len(self.nics) < num_nics:
self.nics.append(NICObservation(where=None, include_nmne=include_nmne))
while len(self.nics) > num_nics:
truncated_nic = self.nics.pop()
msg = f"Too many network_interfaces in Node observation space for node. Truncating {truncated_nic.where}"
_LOGGER.warning(msg)
self.default_observation: ObsType = {
"operating_status": 0,
}
if self.services:
self.default_observation["SERVICES"] = {i + 1: s.default_observation for i, s in enumerate(self.services)}
if self.applications:
self.default_observation["APPLICATIONS"] = {
i + 1: a.default_observation for i, a in enumerate(self.applications)
}
if self.folders:
self.default_observation["FOLDERS"] = {i + 1: f.default_observation for i, f in enumerate(self.folders)}
if self.nics:
self.default_observation["NICS"] = {i + 1: n.default_observation for i, n in enumerate(self.nics)}
if self.include_num_access:
self.default_observation["num_file_creations"] = 0
self.default_observation["num_file_deletions"] = 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status information about the host.
:rtype: ObsType
"""
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["operating_status"] = node_state["operating_state"]
if self.services:
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
if self.applications:
obs["APPLICATIONS"] = {i + 1: app.observe(state) for i, app in enumerate(self.applications)}
if self.folders:
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
if self.nics:
obs["NICS"] = {i + 1: nic.observe(state) for i, nic in enumerate(self.nics)}
if self.include_num_access:
obs["num_file_creations"] = node_state["file_system"]["num_file_creations"]
obs["num_file_deletions"] = node_state["file_system"]["num_file_deletions"]
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for host status.
:rtype: spaces.Space
"""
shape = {
"operating_status": spaces.Discrete(5),
}
if self.services:
shape["SERVICES"] = spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)})
if self.applications:
shape["APPLICATIONS"] = spaces.Dict({i + 1: app.space for i, app in enumerate(self.applications)})
if self.folders:
shape["FOLDERS"] = spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)})
if self.nics:
shape["NICS"] = spaces.Dict({i + 1: nic.space for i, nic in enumerate(self.nics)})
if self.include_num_access:
shape["num_file_creations"] = spaces.Discrete(4)
shape["num_file_deletions"] = spaces.Discrete(4)
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> HostObservation:
"""
Create a host observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the host observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this host.
A typical location might be ['network', 'nodes', <hostname>].
:type parent_where: WhereType, optional
:return: Constructed host observation instance.
:rtype: HostObservation
"""
if parent_where == []:
where = ["network", "nodes", config.hostname]
else:
where = parent_where + ["nodes", config.hostname]
# Pass down shared/common config items
for folder_config in config.folders:
folder_config.include_num_access = config.include_num_access
folder_config.num_files = config.num_files
for nic_config in config.network_interfaces:
nic_config.include_nmne = config.include_nmne
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in config.services]
applications = [
ApplicationObservation.from_config(config=c, game=game, parent_where=where) for c in config.applications
]
folders = [FolderObservation.from_config(config=c, game=game, parent_where=where) for c in config.folders]
nics = [NICObservation.from_config(config=c, game=game, parent_where=where) for c in config.network_interfaces]
return cls(
where=where,
services=services,
applications=applications,
folders=folders,
network_interfaces=nics,
num_services=config.num_services,
num_applications=config.num_applications,
num_folders=config.num_folders,
num_files=config.num_files,
num_nics=config.num_nics,
include_nmne=config.include_nmne,
include_num_access=config.include_num_access,
)

View File

@@ -0,0 +1,155 @@
from __future__ import annotations
from typing import Any, Dict, List, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class LinkObservation(AbstractObservation, identifier="LINK"):
"""Link observation, providing information about a specific link within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinkObservation."""
link_reference: str
"""Reference identifier for the link."""
def __init__(self, where: WhereType) -> None:
"""
Initialise a link observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this link.
A typical location for a link might be ['network', 'links', <link_reference>].
:type where: WhereType
"""
self.where = where
self.default_observation: ObsType = {"PROTOCOLS": {"ALL": 0}}
def observe(self, state: Dict) -> Any:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about the link.
:rtype: Any
"""
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
utilisation_category = int(utilisation_fraction * 9) + 1
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for link status.
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinkObservation:
"""
Create a link observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the link observation.
:type config: ConfigSchema
:param game: The PrimaiteGame instance.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this link.
A typical location might be ['network', 'links', <link_reference>].
:type parent_where: WhereType, optional
:return: Constructed link observation instance.
:rtype: LinkObservation
"""
link_reference = game.ref_map_links[config.link_reference]
if parent_where == []:
where = ["network", "links", link_reference]
else:
where = parent_where + ["links", link_reference]
return cls(where=where)
class LinksObservation(AbstractObservation, identifier="LINKS"):
"""Collection of link observations representing multiple links within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for LinksObservation."""
link_references: List[str]
"""List of reference identifiers for the links."""
def __init__(self, where: WhereType, links: List[LinkObservation]) -> None:
"""
Initialise a links observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for these links.
A typical location for links might be ['network', 'links'].
:type where: WhereType
:param links: List of link observations.
:type links: List[LinkObservation]
"""
self.where: WhereType = where
self.links: List[LinkObservation] = links
self.default_observation: ObsType = {i + 1: l.default_observation for i, l in enumerate(self.links)}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing information about multiple links.
:rtype: ObsType
"""
return {i + 1: l.observe(state) for i, l in enumerate(self.links)}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for multiple links.
:rtype: spaces.Space
"""
return spaces.Dict({i + 1: l.space for i, l in enumerate(self.links)})
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> LinksObservation:
"""
Create a links observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the links observation.
:type config: ConfigSchema
:param game: The PrimaiteGame instance.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about these links.
A typical location might be ['network'].
:type parent_where: WhereType, optional
:return: Constructed links observation instance.
:rtype: LinksObservation
"""
where = parent_where + ["network"]
link_cfgs = [LinkObservation.ConfigSchema(link_reference=ref) for ref in config.link_references]
links = [LinkObservation.from_config(c, game=game, parent_where=where) for c in link_cfgs]
return cls(where=where, links=links)

View File

@@ -1,97 +1,56 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
class NICObservation(AbstractObservation, identifier="NETWORK_INTERFACE"):
"""Status information about a network interface within the simulation environment."""
low_nmne_threshold: int = 0
"""The minimum number of malicious network events to be considered low."""
med_nmne_threshold: int = 5
"""The minimum number of malicious network events to be considered medium."""
high_nmne_threshold: int = 10
"""The minimum number of malicious network events to be considered high."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NICObservation."""
global CAPTURE_NMNE
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
return data
nic_num: int
"""Number of the network interface."""
include_nmne: Optional[bool] = None
"""Whether to include number of malicious network events (NMNE) in the observation."""
def __init__(
self,
where: Optional[Tuple[str]] = None,
low_nmne_threshold: Optional[int] = 0,
med_nmne_threshold: Optional[int] = 5,
high_nmne_threshold: Optional[int] = 10,
where: WhereType,
include_nmne: bool,
) -> None:
"""Initialise NIC observation.
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialise a network interface observation instance.
global CAPTURE_NMNE
if CAPTURE_NMNE:
:param where: Where in the simulation state dictionary to find the relevant information for this interface.
A typical location for a network interface might be
['network', 'nodes', <node_hostname>, 'NICs', <nic_num>].
:type where: WhereType
:param include_nmne: Flag to determine whether to include NMNE information in the observation.
:type include_nmne: bool
"""
self.where = where
self.include_nmne: bool = include_nmne
self.default_observation: ObsType = {"nic_status": 0}
if self.include_nmne:
self.default_observation.update({"NMNE": {"inbound": 0, "outbound": 0}})
self.nmne_inbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
self.nmne_outbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
self._validate_nmne_categories(
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
def _validate_nmne_categories(
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
):
"""
Validates the nmne threshold config.
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
"""
if high_nmne_threshold <= med_nmne_threshold:
raise Exception(
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
f"than medium nmne count ({med_nmne_threshold})"
)
if med_nmne_threshold <= low_nmne_threshold:
raise Exception(
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
f"than low nmne count ({low_nmne_threshold})"
)
self.high_nmne_threshold = high_nmne_threshold
self.med_nmne_threshold = med_nmne_threshold
self.low_nmne_threshold = low_nmne_threshold
# TODO: allow these to be configured in yaml
self.high_nmne_threshold = 10
self.med_nmne_threshold = 5
self.low_nmne_threshold = 0
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
@@ -116,73 +75,120 @@ class NicObservation(AbstractObservation):
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of the network interface and optionally NMNE information.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs_dict
obs = {"nic_status": 1 if nic_state["enabled"] else 2}
if self.include_nmne:
obs.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for network interface status and NMNE information.
:rtype: spaces.Space
"""
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if CAPTURE_NMNE:
if self.include_nmne:
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NICObservation:
"""
low_nmne_threshold = None
med_nmne_threshold = None
high_nmne_threshold = None
Create a network interface observation from a configuration schema.
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
threshold = game.options.thresholds["nmne"]
:param config: Configuration schema containing the necessary information for the network interface observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed network interface observation instance.
:rtype: NICObservation
"""
return cls(where=parent_where + ["NICs", config.nic_num], include_nmne=config.include_nmne)
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
return cls(
where=parent_where + ["NICs", config["nic_num"]],
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
class PortObservation(AbstractObservation, identifier="PORT"):
"""Port observation, provides status information about a network port within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for PortObservation."""
port_id: int
"""Identifier of the port, used for querying simulation state dictionary."""
def __init__(self, where: WhereType) -> None:
"""
Initialise a port observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this port.
A typical location for a port might be ['network', 'nodes', <node_hostname>, 'NICs', <port_id>].
:type where: WhereType
"""
self.where = where
self.default_observation: ObsType = {"operating_status": 0}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the operating status of the port.
:rtype: ObsType
"""
port_state = access_from_nested_dict(state, self.where)
if port_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"operating_status": 1 if port_state["enabled"] else 2}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for port status.
:rtype: spaces.Space
"""
return spaces.Dict({"operating_status": spaces.Discrete(3)})
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> PortObservation:
"""
Create a port observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the port observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this port's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed port observation instance.
:rtype: PortObservation
"""
return cls(where=parent_where + ["NICs", config.port_id])

View File

@@ -1,200 +1,218 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import model_validator
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.software_observation import ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
from primaite.game.agent.observations.firewall_observation import FirewallObservation
from primaite.game.agent.observations.host_observations import HostObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.observations.router_observation import RouterObservation
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class NodeObservation(AbstractObservation):
"""Observation of a node in the network. Includes services, folders and NICs."""
class NodesObservation(AbstractObservation, identifier="NODES"):
"""Nodes observation, provides status information about nodes within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NodesObservation."""
hosts: List[HostObservation.ConfigSchema] = []
"""List of configurations for host observations."""
routers: List[RouterObservation.ConfigSchema] = []
"""List of configurations for router observations."""
firewalls: List[FirewallObservation.ConfigSchema] = []
"""List of configurations for firewall observations."""
num_services: Optional[int] = None
"""Number of services."""
num_applications: Optional[int] = None
"""Number of applications."""
num_folders: Optional[int] = None
"""Number of folders."""
num_files: Optional[int] = None
"""Number of files."""
num_nics: Optional[int] = None
"""Number of network interface cards (NICs)."""
include_nmne: Optional[bool] = None
"""Flag to include nmne."""
include_num_access: Optional[bool] = None
"""Flag to include the number of accesses."""
num_ports: Optional[int] = None
"""Number of ports."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
@model_validator(mode="after")
def force_optional_fields(self) -> NodesObservation.ConfigSchema:
"""Check that options are specified only if they are needed for the nodes that are part of the config."""
# check for hosts:
host_fields = (
self.num_services,
self.num_applications,
self.num_folders,
self.num_files,
self.num_nics,
self.include_nmne,
self.include_num_access,
)
router_fields = (
self.num_ports,
self.ip_list,
self.wildcard_list,
self.port_list,
self.protocol_list,
self.num_rules,
)
firewall_fields = (self.ip_list, self.wildcard_list, self.port_list, self.protocol_list, self.num_rules)
if len(self.hosts) > 0 and any([x is None for x in host_fields]):
raise ValueError("Configuration error: Host observation options were not fully specified.")
if len(self.routers) > 0 and any([x is None for x in router_fields]):
raise ValueError("Configuration error: Router observation options were not fully specified.")
if len(self.firewalls) > 0 and any([x is None for x in firewall_fields]):
raise ValueError("Configuration error: Firewall observation options were not fully specified.")
return self
def __init__(
self,
where: Optional[Tuple[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
network_interfaces: List[NicObservation] = [],
logon_status: bool = False,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
where: WhereType,
hosts: List[HostObservation],
routers: List[RouterObservation],
firewalls: List[FirewallObservation],
) -> None:
"""
Configurable observation for a node in the simulation.
Initialise a nodes observation instance.
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service name, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
:type max_services: int, optional
:param folders: Mapping between position in observation space and folder name, defaults to {}
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
:type network_interfaces: Dict[int,str], optional
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
:type max_nics: int, optional
:param where: Where in the simulation state dictionary to find the relevant information for nodes.
A typical location for nodes might be ['network', 'nodes'].
:type where: WhereType
:param hosts: List of host observations.
:type hosts: List[HostObservation]
:param routers: List of router observations.
:type routers: List[RouterObservation]
:param firewalls: List of firewall observations.
:type firewalls: List[FirewallObservation]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.where: WhereType = where
self.services: List[ServiceObservation] = services
while len(self.services) < num_services_per_node:
# add empty service observation without `where` parameter so it always returns default (blank) observation
self.services.append(ServiceObservation())
while len(self.services) > num_services_per_node:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
# truncate service list
self.hosts: List[HostObservation] = hosts
self.routers: List[RouterObservation] = routers
self.firewalls: List[FirewallObservation] = firewalls
self.folders: List[FolderObservation] = folders
# add empty folder observation without `where` parameter that will always return default (blank) observations
while len(self.folders) < num_folders_per_node:
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warning(msg)
self.network_interfaces: List[NicObservation] = network_interfaces
while len(self.network_interfaces) < num_nics_per_node:
self.network_interfaces.append(NicObservation())
while len(self.network_interfaces) > num_nics_per_node:
truncated_nic = self.network_interfaces.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warning(msg)
self.logon_status: bool = logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
self.default_observation = {
**{f"HOST{i}": host.default_observation for i, host in enumerate(self.hosts)},
**{f"ROUTER{i}": router.default_observation for i, router in enumerate(self.routers)},
**{f"FIREWALL{i}": firewall.default_observation for i, firewall in enumerate(self.firewalls)},
}
if self.logon_status:
self.default_observation["logon_status"] = 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing status information about nodes.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
obs = {
**{f"HOST{i}": host.observe(state) for i, host in enumerate(self.hosts)},
**{f"ROUTER{i}": router.observe(state) for i, router in enumerate(self.routers)},
**{f"FIREWALL{i}": firewall.observe(state) for i, firewall in enumerate(self.firewalls)},
}
if self.logon_status:
obs["logon_status"] = 0
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
space_shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
}
if self.logon_status:
space_shape["logon_status"] = spaces.Discrete(3)
"""
Gymnasium space object describing the observation space shape.
return spaces.Dict(space_shape)
:return: Gymnasium space representing the observation space for nodes.
:rtype: spaces.Space
"""
space = spaces.Dict(
{
**{f"HOST{i}": host.space for i, host in enumerate(self.hosts)},
**{f"ROUTER{i}": router.space for i, router in enumerate(self.routers)},
**{f"FIREWALL{i}": firewall.space for i, firewall in enumerate(self.firewalls)},
}
)
return space
@classmethod
def from_config(
cls,
config: Dict,
game: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> "NodeObservation":
"""Create node observation from a config. Also creates child service, folder and NIC observations.
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
observation size) , defaults to 2
:type num_services_per_node: int, optional
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
observation size) , defaults to 2
:type num_folders_per_node: int, optional
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed node observation
:rtype: NodeObservation
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NodesObservation:
"""
node_hostname = config["node_hostname"]
if parent_where is None:
where = ["network", "nodes", node_hostname]
else:
where = parent_where + ["nodes", node_hostname]
Create a nodes observation from a configuration schema.
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
services=services,
folders=folders,
network_interfaces=network_interfaces,
logon_status=logon_status,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
:param config: Configuration schema containing the necessary information for nodes observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about nodes.
A typical location for nodes might be ['network', 'nodes'].
:type parent_where: WhereType, optional
:return: Constructed nodes observation instance.
:rtype: NodesObservation
"""
if parent_where is None:
where = ["network", "nodes"]
else:
where = parent_where + ["nodes"]
for host_config in config.hosts:
if host_config.num_services is None:
host_config.num_services = config.num_services
if host_config.num_applications is None:
host_config.num_applications = config.num_applications
if host_config.num_folders is None:
host_config.num_folders = config.num_folders
if host_config.num_files is None:
host_config.num_files = config.num_files
if host_config.num_nics is None:
host_config.num_nics = config.num_nics
if host_config.include_nmne is None:
host_config.include_nmne = config.include_nmne
if host_config.include_num_access is None:
host_config.include_num_access = config.include_num_access
for router_config in config.routers:
if router_config.num_ports is None:
router_config.num_ports = config.num_ports
if router_config.ip_list is None:
router_config.ip_list = config.ip_list
if router_config.wildcard_list is None:
router_config.wildcard_list = config.wildcard_list
if router_config.port_list is None:
router_config.port_list = config.port_list
if router_config.protocol_list is None:
router_config.protocol_list = config.protocol_list
if router_config.num_rules is None:
router_config.num_rules = config.num_rules
for firewall_config in config.firewalls:
if firewall_config.ip_list is None:
firewall_config.ip_list = config.ip_list
if firewall_config.wildcard_list is None:
firewall_config.wildcard_list = config.wildcard_list
if firewall_config.port_list is None:
firewall_config.port_list = config.port_list
if firewall_config.protocol_list is None:
firewall_config.protocol_list = config.protocol_list
if firewall_config.num_rules is None:
firewall_config.num_rules = config.num_rules
hosts = [HostObservation.from_config(config=c, game=game, parent_where=where) for c in config.hosts]
routers = [RouterObservation.from_config(config=c, game=game, parent_where=where) for c in config.routers]
firewalls = [FirewallObservation.from_config(config=c, game=game, parent_where=where) for c in config.firewalls]
return cls(where=where, hosts=hosts, routers=routers, firewalls=firewalls)

View File

@@ -1,18 +1,149 @@
from typing import Dict, TYPE_CHECKING
from __future__ import annotations
from typing import Any, Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict, model_validator, ValidationError
from primaite.game.agent.observations.agent_observations import (
UC2BlueObservation,
UC2GreenObservation,
UC2RedObservation,
)
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NestedObservation(AbstractObservation, identifier="CUSTOM"):
"""Observation type that allows combining other observations into a gymnasium.spaces.Dict space."""
class NestedObservationItem(BaseModel):
"""One list item of the config."""
model_config = ConfigDict(extra="forbid")
type: str
"""Select observation class. It maps to the identifier of the obs class by checking the registry."""
label: str
"""Dict key in the final observation space."""
options: Dict
"""Options to pass to the observation class from_config method."""
@model_validator(mode="after")
def check_model(self) -> "NestedObservation.NestedObservationItem":
"""Make sure tha the config options match up with the selected observation type."""
obs_subclass_name = self.type
obs_options = self.options
if obs_subclass_name not in AbstractObservation._registry:
raise ValueError(f"Observation of type {obs_subclass_name} could not be found.")
obs_schema = AbstractObservation._registry[obs_subclass_name].ConfigSchema
try:
obs_schema(**obs_options)
except ValidationError as e:
raise ValueError(f"Observation options did not match schema, got this error: {e}")
return self
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for NestedObservation."""
components: List[NestedObservation.NestedObservationItem] = []
"""List of observation components to be part of this space."""
def __init__(self, components: Dict[str, AbstractObservation]) -> None:
"""Initialise nested observation."""
self.components: Dict[str, AbstractObservation] = components
"""Maps label: observation object"""
self.default_observation = {label: obs.default_observation for label, obs in self.components.items()}
"""Default observation is just the default observations of constituents."""
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status information about the host.
:rtype: ObsType
"""
return {label: obs.observe(state) for label, obs in self.components.items()}
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the nested observation space.
:rtype: spaces.Space
"""
return spaces.Dict({label: obs.space for label, obs in self.components.items()})
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> NestedObservation:
"""
Read the Nested observation config and create all defined subcomponents.
Example configuration that utilises NestedObservation:
This lets us have different options for different types of hosts.
```yaml
observation_space:
- type: CUSTOM
options:
components:
- type: HOSTS
label: COMPUTERS # What is the dictionary key called
options:
hosts:
- client_1
- client_2
num_services: 0
num_applications: 5
... # other options
- type: HOSTS
label: SERVERS # What is the dictionary key called
options:
hosts:
- hostname: database_server
- hostname: web_server
num_services: 4
num_applications: 0
num_folders: 2
num_files: 2
```
"""
instances = dict()
for component in config.components:
obs_class = AbstractObservation._registry[component.type]
obs_instance = obs_class.from_config(config=obs_class.ConfigSchema(**component.options), game=game)
instances[component.label] = obs_instance
return cls(components=instances)
class NullObservation(AbstractObservation, identifier="NONE"):
"""Empty observation that acts as a placeholder."""
def __init__(self) -> None:
"""Initialise the empty observation."""
self.default_observation = 0
def observe(self, state: Dict) -> Any:
"""Simply return 0."""
return 0
@property
def space(self) -> spaces.Space:
"""Essentially empty space."""
return spaces.Discrete(1)
@classmethod
def from_config(
cls, config: NullObservation.ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> NullObservation:
"""Instantiate a NullObservation. Accepts parameters to comply with API."""
return cls()
class ObservationManager:
"""
Manage the observations of an Agent.
@@ -23,18 +154,15 @@ class ObservationManager:
3. Formatting this information so an agent can use it to make decisions.
"""
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
# refactor.
def __init__(self, observation: AbstractObservation) -> None:
def __init__(self, obs: AbstractObservation) -> None:
"""Initialise observation space.
:param observation: Observation object
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.obs: AbstractObservation = obs
self.current_observation: ObsType
"""Cached copy of the observation at the time it was most recently calculated."""
def update(self, state: Dict) -> Dict:
"""
@@ -52,22 +180,25 @@ class ObservationManager:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
def from_config(cls, config: Optional[Dict], game: "PrimaiteGame") -> "ObservationManager":
"""
Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
It should contain the key 'type' which selects which observation class to use (from a choice of:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
If None, a blank observation space is created.
Otherwise, this must be a Dict with a type field and options field.
type: string that corresponds to one of the observation identifiers that are provided when subclassing
AbstractObservation
options: this must adhere to the chosen observation type's ConfigSchema nested class.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
else:
raise ValueError("Observation space type invalid")
if config is None:
return cls(NullObservation())
print(config)
obs_type = config["type"]
obs_class = AbstractObservation._registry[obs_type]
observation = obs_class.from_config(config=obs_class.ConfigSchema(**config["options"]), game=game)
obs_manager = cls(observation)
return obs_manager

View File

@@ -1,22 +1,50 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import Any, Dict, Iterable, Type, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
WhereType = Iterable[str | int] | None
class AbstractObservation(ABC):
"""Abstract class for an observation space component."""
class ConfigSchema(ABC, BaseModel):
"""Config schema for observations."""
model_config = ConfigDict(extra="forbid")
_registry: Dict[str, Type["AbstractObservation"]] = {}
"""Registry of observation components, with their name as key.
Automatically populated when subclasses are defined. Used for defining from_config.
"""
def __init__(self) -> None:
"""Initialise an observation. This method must be overwritten."""
self.default_observation: ObsType
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register an observation type.
:param identifier: Identifier used to uniquely specify observation component types.
:type identifier: str
:raises ValueError: When attempting to create a component with a name that is already in use.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate observation component type {identifier}")
cls._registry[identifier] = cls
@abstractmethod
def observe(self, state: Dict) -> Any:
"""
@@ -37,273 +65,8 @@ class AbstractObservation(ABC):
@classmethod
@abstractmethod
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `game` parameter is for a the PrimaiteGame object that spawns this component.
"""
pass
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_hostname"]
nic_num = ip_map_config["nic_num"]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.network_interface[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
def from_config(
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> "AbstractObservation":
"""Create this observation space component form a serialised format."""
return cls()
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
pass

View File

@@ -0,0 +1,147 @@
from __future__ import annotations
from typing import Dict, List, Optional, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
_LOGGER = getLogger(__name__)
class RouterObservation(AbstractObservation, identifier="ROUTER"):
"""Router observation, provides status information about a router within the simulation environment."""
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for RouterObservation."""
hostname: str
"""Hostname of the router, used for querying simulation state dictionary."""
ports: Optional[List[PortObservation.ConfigSchema]] = None
"""Configuration of port observations for this router."""
num_ports: Optional[int] = None
"""Number of port observations configured for this router."""
acl: Optional[ACLObservation.ConfigSchema] = None
"""Configuration of ACL observation on this router."""
ip_list: Optional[List[str]] = None
"""List of IP addresses for encoding ACLs."""
wildcard_list: Optional[List[str]] = None
"""List of IP wildcards for encoding ACLs."""
port_list: Optional[List[int]] = None
"""List of ports for encoding ACLs."""
protocol_list: Optional[List[str]] = None
"""List of protocols for encoding ACLs."""
num_rules: Optional[int] = None
"""Number of rules ACL rules to show."""
def __init__(
self,
where: WhereType,
ports: List[PortObservation],
num_ports: int,
acl: ACLObservation,
) -> None:
"""
Initialise a router observation instance.
:param where: Where in the simulation state dictionary to find the relevant information for this router.
A typical location for a router might be ['network', 'nodes', <node_hostname>].
:type where: WhereType
:param ports: List of port observations representing the ports of the router.
:type ports: List[PortObservation]
:param num_ports: Number of ports for the router.
:type num_ports: int
:param acl: ACL observation representing the access control list of the router.
:type acl: ACLObservation
"""
self.where: WhereType = where
self.ports: List[PortObservation] = ports
self.acl: ACLObservation = acl
self.num_ports: int = num_ports
while len(self.ports) < num_ports:
self.ports.append(PortObservation(where=None))
while len(self.ports) > num_ports:
self.ports.pop()
msg = "Too many ports in router observation. Truncating."
_LOGGER.warning(msg)
self.default_observation = {
"ACL": self.acl.default_observation,
}
if self.ports:
self.default_observation["PORTS"] = {i + 1: p.default_observation for i, p in enumerate(self.ports)}
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation containing the status of ports and ACL configuration of the router.
:rtype: ObsType
"""
router_state = access_from_nested_dict(state, self.where)
if router_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["ACL"] = self.acl.observe(state)
if self.ports:
obs["PORTS"] = {i + 1: p.observe(state) for i, p in enumerate(self.ports)}
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for router status.
:rtype: spaces.Space
"""
shape = {"ACL": self.acl.space}
if self.ports:
shape["PORTS"] = spaces.Dict({i + 1: p.space for i, p in enumerate(self.ports)})
return spaces.Dict(shape)
@classmethod
def from_config(cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []) -> RouterObservation:
"""
Create a router observation from a configuration schema.
:param config: Configuration schema containing the necessary information for the router observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this router's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed router observation instance.
:rtype: RouterObservation
"""
where = parent_where + ["nodes", config.hostname]
if config.acl is None:
config.acl = ACLObservation.ConfigSchema()
if config.acl.num_rules is None:
config.acl.num_rules = config.num_rules
if config.acl.ip_list is None:
config.acl.ip_list = config.ip_list
if config.acl.wildcard_list is None:
config.acl.wildcard_list = config.wildcard_list
if config.acl.port_list is None:
config.acl.port_list = config.port_list
if config.acl.protocol_list is None:
config.acl.protocol_list = config.protocol_list
if config.ports is None:
config.ports = [PortObservation.ConfigSchema(port_id=i + 1) for i in range(config.num_ports)]
ports = [PortObservation.from_config(config=c, game=game, parent_where=where) for c in config.ports]
acl = ACLObservation.from_config(config=config.acl, game=game, parent_where=where)
return cls(where=where, ports=ports, num_ports=config.num_ports, acl=acl)

View File

@@ -1,45 +1,46 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from __future__ import annotations
from typing import Dict, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.observations import AbstractObservation, WhereType
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class ServiceObservation(AbstractObservation):
"""Observation of a service in the network."""
class ServiceObservation(AbstractObservation, identifier="SERVICE"):
"""Service observation, shows status of a service in the simulation environment."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ServiceObservation."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise service observation.
service_name: str
"""Name of the service, used for querying simulation state dictionary"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'services', <service_name>]`
:type where: Optional[List[str]]
def __init__(self, where: WhereType) -> None:
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialise a service observation instance.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param where: Where in the simulation state dictionary to find the relevant information for this service.
A typical location for a service might be ['network', 'nodes', <node_hostname>, 'services', <service_name>].
:type where: WhereType
"""
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0}
:param state: Simulation state dictionary
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Observation containing the operating status and health status of the service.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
@@ -50,114 +51,120 @@ class ServiceObservation(AbstractObservation):
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for service status.
:rtype: spaces.Space
"""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> ServiceObservation:
"""
Create a service observation from a configuration schema.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:param config: Configuration schema containing the necessary information for the service observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this service's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed service observation instance.
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", config["service_name"]])
return cls(where=parent_where + ["services", config.service_name])
class ApplicationObservation(AbstractObservation):
"""Observation of an application in the network."""
class ApplicationObservation(AbstractObservation, identifier="APPLICATION"):
"""Application observation, shows the status of an application within the simulation environment."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
"Default observation is what should be returned when the application doesn't exist."
class ConfigSchema(AbstractObservation.ConfigSchema):
"""Configuration schema for ApplicationObservation."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise application observation.
application_name: str
"""Name of the application, used for querying simulation state dictionary"""
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'applications', <application_name>]`
:type where: Optional[List[str]]
def __init__(self, where: WhereType) -> None:
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
Initialise an application observation instance.
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param where: Where in the simulation state dictionary to find the relevant information for this application.
A typical location for an application might be
['network', 'nodes', <node_hostname>, 'applications', <application_name>].
:type where: WhereType
"""
self.where = where
self.default_observation = {"operating_status": 0, "health_status": 0, "num_executions": 0}
:param state: Simulation state dictionary
# TODO: allow these to be configured in yaml
self.high_threshold = 10
self.med_threshold = 5
self.low_threshold = 0
def _categorise_num_executions(self, num_executions: int) -> int:
"""
Represent number of file accesses as a categorical variable.
:param num_access: Number of file accesses.
:return: Bin number corresponding to the number of accesses.
"""
if num_executions > self.high_threshold:
return 3
elif num_executions > self.med_threshold:
return 2
elif num_executions > self.low_threshold:
return 1
return 0
def observe(self, state: Dict) -> ObsType:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary.
:type state: Dict
:return: Observation
:rtype: Dict
:return: Obs containing the operating status, health status, and number of executions of the application.
:rtype: ObsType
"""
if self.where is None:
return self.default_observation
app_state = access_from_nested_dict(state, self.where)
if app_state is NOT_PRESENT_IN_STATE:
application_state = access_from_nested_dict(state, self.where)
if application_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {
"operating_status": app_state["operating_state"],
"health_status": app_state["health_state_visible"],
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
"operating_status": application_state["operating_state"],
"health_status": application_state["health_state_visible"],
"num_executions": self._categorise_num_executions(application_state["num_executions"]),
}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
"""
Gymnasium space object describing the observation space shape.
:return: Gymnasium space representing the observation space for application status.
:rtype: spaces.Space
"""
return spaces.Dict(
{
"operating_status": spaces.Discrete(7),
"health_status": spaces.Discrete(6),
"health_status": spaces.Discrete(5),
"num_executions": spaces.Discrete(4),
}
)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ApplicationObservation":
"""Create application observation from a config.
cls, config: ConfigSchema, game: "PrimaiteGame", parent_where: WhereType = []
) -> ApplicationObservation:
"""
Create an application observation from a configuration schema.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:param config: Configuration schema containing the necessary information for the application observation.
:type config: ConfigSchema
:param parent_where: Where in the simulation state dictionary to find the information about this application's
parent node. A typical location for a node might be ['network', 'nodes', <node_hostname>].
:type parent_where: WhereType, optional
:return: Constructed application observation instance.
:rtype: ApplicationObservation
"""
return cls(where=parent_where + ["services", config["application_name"]])
@classmethod
def _categorise_num_executions(cls, num_executions: int) -> int:
"""
Categorise the number of executions of an application.
Helps classify the number of application executions into different categories.
Current categories:
- 0: Application is never executed
- 1: Application is executed a low number of times (1-5)
- 2: Application is executed often (6-10)
- 3: Application is executed a high number of times (more than 10)
:param: num_executions: Number of times the application is executed
"""
if num_executions > 10:
return 3
elif num_executions > 5:
return 2
elif num_executions > 0:
return 1
return 0
return cls(where=parent_where + ["applications", config.application_name])

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Hashable, Sequence
from typing import Any, Dict, Hashable, Optional, Sequence
NOT_PRESENT_IN_STATE = object()
"""
@@ -7,7 +7,7 @@ the thing requested in the state could equal None. This NOT_PRESENT_IN_STATE is
"""
def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
def access_from_nested_dict(dictionary: Dict, keys: Optional[Sequence[Hashable]]) -> Any:
"""
Access an item from a deeply dictionary with a list of keys.
@@ -21,6 +21,8 @@ def access_from_nested_dict(dictionary: Dict, keys: Sequence[Hashable]) -> Any:
:return: The value in the dictionary
:rtype: Any
"""
if keys is None:
return NOT_PRESENT_IN_STATE
key_list = [*keys] # copy keys to a new list to prevent editing original list
if len(key_list) == 0:
return dictionary

View File

@@ -148,8 +148,10 @@ class ACLRule(SimComponent):
state["action"] = self.action.value
state["protocol"] = self.protocol.name if self.protocol else None
state["src_ip_address"] = str(self.src_ip_address) if self.src_ip_address else None
state["src_wildcard_mask"] = str(self.src_wildcard_mask) if self.src_wildcard_mask else None
state["src_port"] = self.src_port.name if self.src_port else None
state["dst_ip_address"] = str(self.dst_ip_address) if self.dst_ip_address else None
state["dst_wildcard_mask"] = str(self.dst_wildcard_mask) if self.dst_wildcard_mask else None
state["dst_port"] = self.dst_port.name if self.dst_port else None
state["match_count"] = self.match_count
return state

View File

@@ -6,7 +6,7 @@ CAPTURE_NMNE: bool = True
NMNE_CAPTURE_KEYWORDS: List[str] = []
"""List of keywords to identify malicious network events."""
# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically
# TODO: Remove final and make configurable after example layout when the NICObservation creates nmne structure dynamically
CAPTURE_BY_DIRECTION: Final[bool] = True
"""Flag to determine if captures should be organized by traffic direction (inbound/outbound)."""
CAPTURE_BY_IP_ADDRESS: Final[bool] = False

View File

@@ -22,8 +22,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -50,10 +49,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -86,63 +82,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -41,8 +41,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING

View File

@@ -41,8 +41,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING

View File

@@ -66,8 +66,7 @@ agents:
- ref: client_1_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING

View File

@@ -26,8 +26,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -58,10 +57,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -102,63 +98,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -64,25 +64,67 @@ agents:
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: client_1
links:
- link_ref: client_1___switch_1
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: client_1
nic_num: 1
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: client_1
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.0.10
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- client_1___switch_1
- type: "NONE"
label: ICS
options: {}
# observation_space:
# type: UC2BlueObservation
# options:
# num_services_per_node: 1
# num_folders_per_node: 1
# num_files_per_folder: 1
# num_nics_per_node: 2
# nodes:
# - node_hostname: client_1
# links:
# - link_ref: client_1___switch_1
# acl:
# options:
# max_acl_rules: 10
# router_hostname: router_1
# ip_address_order:
# - node_hostname: client_1
# nic_num: 1
# ics: null
action_space:
action_list:
- type: DONOTHING

View File

@@ -32,8 +32,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -61,10 +60,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -97,63 +93,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
@@ -553,63 +559,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -41,8 +41,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -91,8 +90,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -141,10 +139,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -177,61 +172,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -41,8 +41,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -91,8 +90,7 @@ agents:
0: 0.3
1: 0.6
2: 0.1
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -141,10 +139,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -177,61 +172,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: DNSServer
- node_hostname: web_server
services:
- service_name: WebServer
- node_hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -33,8 +33,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -65,10 +64,7 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
@@ -110,65 +106,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
# services:
# - service_name: backup_service
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -26,8 +26,7 @@ agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space:
type: UC2GreenObservation
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -65,10 +64,8 @@ agents:
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes: {}
observation_space: null
action_space:
action_list:
- type: DONOTHING
@@ -109,63 +106,73 @@ agents:
type: ProxyAgent
observation_space:
type: UC2BlueObservation
type: CUSTOM
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_hostname: domain_controller
services:
- service_name: domain_controller_dns_server
- node_hostname: web_server
services:
- service_name: web_server_database_client
- node_hostname: database_server
services:
- service_name: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_hostname: backup_server
- node_hostname: security_suite
- node_hostname: client_1
- node_hostname: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_hostname: router_1
ip_address_order:
- node_hostname: domain_controller
nic_num: 1
- node_hostname: web_server
nic_num: 1
- node_hostname: database_server
nic_num: 1
- node_hostname: backup_server
nic_num: 1
- node_hostname: security_suite
nic_num: 1
- node_hostname: client_1
nic_num: 1
- node_hostname: client_2
nic_num: 1
- node_hostname: security_suite
nic_num: 2
ics: null
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1___switch_1
- router_1___switch_2
- switch_1___domain_controller
- switch_1___web_server
- switch_1___database_server
- switch_1___backup_server
- switch_1___security_suite
- switch_2___client_1
- switch_2___client_2
- switch_2___security_suite
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:

View File

@@ -10,8 +10,7 @@ from _pytest.monkeypatch import MonkeyPatch
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.observations.observations import ICSObservation
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession
@@ -533,7 +532,7 @@ def game_and_agent():
ip_address_list=["10.0.1.1", "10.0.1.2", "10.0.2.1", "10.0.2.2", "10.0.2.3"],
act_map={},
)
observation_space = ObservationManager(ICSObservation())
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
test_agent = ControlledAgent(

View File

@@ -13,7 +13,9 @@ MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
class TestPrimaiteSession:
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_creating_session(self, temp_primaite_session):
"""Check that creating a session from config works."""
@@ -56,6 +58,7 @@ class TestPrimaiteSession:
assert checkpoint_2.exists()
assert not checkpoint_3.exists()
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
def test_training_only_session(self, temp_primaite_session):
"""Check that you can run a training-only session."""
@@ -64,6 +67,7 @@ class TestPrimaiteSession:
session.start_session()
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
def test_eval_only_session(self, temp_primaite_session):
"""Check that you can load a model and run an eval-only session."""
@@ -72,6 +76,7 @@ class TestPrimaiteSession:
session.start_session()
# TODO: include checks that the model was loaded and that the eval-only session ran
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.skip(reason="Slow, reenable later")
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
def test_multi_agent_session(self, temp_primaite_session):
@@ -79,10 +84,12 @@ class TestPrimaiteSession:
with temp_primaite_session as session:
session.start_session()
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.skip(
reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, "
"reset doesn't implement software restore."

View File

@@ -1,6 +1,6 @@
import pytest
from primaite.game.agent.observations.observations import AclObservation
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
@@ -34,11 +34,13 @@ def test_acl_observations(simulation):
# add router acl rule
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
acl_obs = AclObservation(
acl_obs = ACLObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
node_ip_to_id={},
ports=["NTP", "HTTP", "POSTGRES_SERVER"],
protocols=["TCP", "UDP", "ICMP"],
ip_list=[],
port_list=["NTP", "HTTP", "POSTGRES_SERVER"],
protocol_list=["TCP", "UDP", "ICMP"],
num_rules=10,
wildcard_list=[],
)
observation_space = acl_obs.observe(simulation.describe_state())
@@ -46,11 +48,11 @@ def test_acl_observations(simulation):
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1)
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
assert rule_obs.get("source_node_id") == 1 # applies to all source nodes
assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes
assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2
assert rule_obs.get("protocol") == 1 # 1 = No Protocol
assert rule_obs.get("source_ip_id") == 1 # applies to all source nodes
assert rule_obs.get("dest_ip_id") == 1 # applies to all destination nodes
assert rule_obs.get("source_port_id") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
assert rule_obs.get("dest_port_id") == 2 # NTP port is mapped to value 2
assert rule_obs.get("protocol_id") == 1 # 1 = No Protocol
router.acl.remove_rule(1)
@@ -59,8 +61,8 @@ def test_acl_observations(simulation):
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
assert rule_obs.get("position") == 0
assert rule_obs.get("permission") == 0
assert rule_obs.get("source_node_id") == 0
assert rule_obs.get("dest_node_id") == 0
assert rule_obs.get("source_port") == 0
assert rule_obs.get("dest_port") == 0
assert rule_obs.get("protocol") == 0
assert rule_obs.get("source_ip_id") == 0
assert rule_obs.get("dest_ip_id") == 0
assert rule_obs.get("source_port_id") == 0
assert rule_obs.get("dest_port_id") == 0
assert rule_obs.get("protocol_id") == 0

View File

@@ -23,7 +23,8 @@ def test_file_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png")
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
)
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
@@ -49,7 +50,10 @@ def test_folder_observation(simulation):
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
root_folder_obs = FolderObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"]
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"],
include_num_access=False,
num_files=1,
files=[],
)
assert root_folder_obs.space["health_status"] == spaces.Discrete(6)
@@ -68,3 +72,6 @@ def test_folder_observation(simulation):
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too
# TODO: Add tests to check num access is correct.

View File

@@ -0,0 +1,128 @@
from primaite.game.agent.observations.firewall_observation import FirewallObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.firewall import Firewall
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
def check_default_rules(acl_obs):
assert len(acl_obs) == 7
assert all(acl_obs[i]["position"] == i - 1 for i in range(1, 8))
assert all(acl_obs[i]["permission"] == 0 for i in range(1, 8))
assert all(acl_obs[i]["source_ip_id"] == 0 for i in range(1, 8))
assert all(acl_obs[i]["source_wildcard_id"] == 0 for i in range(1, 8))
assert all(acl_obs[i]["source_port_id"] == 0 for i in range(1, 8))
assert all(acl_obs[i]["dest_ip_id"] == 0 for i in range(1, 8))
assert all(acl_obs[i]["dest_wildcard_id"] == 0 for i in range(1, 8))
assert all(acl_obs[i]["dest_port_id"] == 0 for i in range(1, 8))
assert all(acl_obs[i]["protocol_id"] == 0 for i in range(1, 8))
def test_firewall_observation():
"""Test adding/removing acl rules and enabling/disabling ports."""
net = Network()
firewall = Firewall(hostname="firewall", operating_state=NodeOperatingState.ON)
firewall_observation = FirewallObservation(
where=[],
num_rules=7,
ip_list=["10.0.0.1", "10.0.0.2"],
wildcard_list=["0.0.0.255", "0.0.0.1"],
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
)
observation = firewall_observation.observe(firewall.describe_state())
assert "ACL" in observation
assert "PORTS" in observation
assert "INTERNAL" in observation["ACL"]
assert "EXTERNAL" in observation["ACL"]
assert "DMZ" in observation["ACL"]
assert "INBOUND" in observation["ACL"]["INTERNAL"]
assert "OUTBOUND" in observation["ACL"]["INTERNAL"]
assert "INBOUND" in observation["ACL"]["EXTERNAL"]
assert "OUTBOUND" in observation["ACL"]["EXTERNAL"]
assert "INBOUND" in observation["ACL"]["DMZ"]
assert "OUTBOUND" in observation["ACL"]["DMZ"]
all_acls = (
observation["ACL"]["INTERNAL"]["INBOUND"],
observation["ACL"]["INTERNAL"]["OUTBOUND"],
observation["ACL"]["EXTERNAL"]["INBOUND"],
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
observation["ACL"]["DMZ"]["INBOUND"],
observation["ACL"]["DMZ"]["OUTBOUND"],
)
for acl_obs in all_acls:
check_default_rules(acl_obs)
# add a rule to the internal inbound and check that the observation is correct
firewall.internal_inbound_acl.add_rule(
action=ACLAction.DENY,
protocol=IPProtocol.TCP,
src_ip_address="10.0.0.1",
src_wildcard_mask="0.0.0.1",
dst_ip_address="10.0.0.2",
dst_wildcard_mask="0.0.0.1",
src_port=Port.HTTP,
dst_port=Port.HTTP,
position=5,
)
observation = firewall_observation.observe(firewall.describe_state())
observed_rule = observation["ACL"]["INTERNAL"]["INBOUND"][5]
assert observed_rule["position"] == 4
assert observed_rule["permission"] == 2
assert observed_rule["source_ip_id"] == 2
assert observed_rule["source_wildcard_id"] == 3
assert observed_rule["source_port_id"] == 2
assert observed_rule["dest_ip_id"] == 3
assert observed_rule["dest_wildcard_id"] == 3
assert observed_rule["dest_port_id"] == 2
assert observed_rule["protocol_id"] == 2
# check that none of the other acls have changed
all_acls = (
observation["ACL"]["INTERNAL"]["OUTBOUND"],
observation["ACL"]["EXTERNAL"]["INBOUND"],
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
observation["ACL"]["DMZ"]["INBOUND"],
observation["ACL"]["DMZ"]["OUTBOUND"],
)
for acl_obs in all_acls:
check_default_rules(acl_obs)
# remove the rule and check that the observation is correct
firewall.internal_inbound_acl.remove_rule(5)
observation = firewall_observation.observe(firewall.describe_state())
all_acls = (
observation["ACL"]["INTERNAL"]["INBOUND"],
observation["ACL"]["INTERNAL"]["OUTBOUND"],
observation["ACL"]["EXTERNAL"]["INBOUND"],
observation["ACL"]["EXTERNAL"]["OUTBOUND"],
observation["ACL"]["DMZ"]["INBOUND"],
observation["ACL"]["DMZ"]["OUTBOUND"],
)
for acl_obs in all_acls:
check_default_rules(acl_obs)
# check that there are three ports in the observation
assert len(observation["PORTS"]) == 3
# check that the ports are all disabled
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))
# connect a switch to the firewall and check that only the correct port is updated
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
link = net.connect(firewall.network_interface[1], switch.network_interface[1])
assert firewall.network_interface[1].enabled
observation = firewall_observation.observe(firewall.describe_state())
assert observation["PORTS"][1]["operating_status"] == 1
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(2, 4))
# disable the port and check that the operating status is updated
firewall.network_interface[1].disable()
assert not firewall.network_interface[1].enabled
observation = firewall_observation.observe(firewall.describe_state())
assert all(observation["PORTS"][i]["operating_status"] == 2 for i in range(1, 4))

View File

@@ -1,11 +1,13 @@
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.observations import LinkObservation
from primaite.game.agent.observations.link_observation import LinkObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.sim_container import Simulation
@@ -49,25 +51,44 @@ def simulation() -> Simulation:
return sim
def test_link_observation(simulation):
"""Test the link observation."""
# get a link
link: Link = next(iter(simulation.network.links.values()))
def test_link_observation():
"""Check the shape and contents of the link observation."""
net = Network()
sim = Simulation(network=net)
switch = Switch(hostname="switch", num_ports=5, operating_state=NodeOperatingState.ON)
computer_1 = Computer(
hostname="computer_1", ip_address="10.0.0.1", subnet_mask="255.255.255.0", start_up_duration=0
)
computer_2 = Computer(
hostname="computer_2", ip_address="10.0.0.2", subnet_mask="255.255.255.0", start_up_duration=0
)
computer_1.power_on()
computer_2.power_on()
link_1 = net.connect(switch.network_interface[1], computer_1.network_interface[1])
link_2 = net.connect(switch.network_interface[2], computer_2.network_interface[1])
assert link_1 is not None
assert link_2 is not None
computer: Computer = simulation.network.get_node_by_hostname("computer")
server: Server = simulation.network.get_node_by_hostname("server")
link_1_observation = LinkObservation(where=["network", "links", link_1.uuid])
link_2_observation = LinkObservation(where=["network", "links", link_2.uuid])
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
state = sim.describe_state()
link_1_obs = link_1_observation.observe(state)
link_2_obs = link_2_observation.observe(state)
assert "PROTOCOLS" in link_1_obs
assert "PROTOCOLS" in link_2_obs
assert "ALL" in link_1_obs["PROTOCOLS"]
assert "ALL" in link_2_obs["PROTOCOLS"]
assert link_1_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
assert link_2_observation.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11)
assert link_1_obs["PROTOCOLS"]["ALL"] == 0
assert link_2_obs["PROTOCOLS"]["ALL"] == 0
link_obs = LinkObservation(where=["network", "links", link.uuid])
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state.get("PROTOCOLS") is not None
assert observation_state["PROTOCOLS"]["ALL"] == 0
computer.ping(server.network_interface.get(1).ip_address)
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state["PROTOCOLS"]["ALL"] == 1
# Test that the link observation is updated when a packet is sent
computer_1.ping("10.0.0.2")
computer_2.ping("10.0.0.1")
state = sim.describe_state()
link_1_obs = link_1_observation.observe(state)
link_2_obs = link_2_observation.observe(state)
assert link_1_obs["PROTOCOLS"]["ALL"] > 0
assert link_2_obs["PROTOCOLS"]["ALL"] > 0

View File

@@ -5,7 +5,7 @@ import pytest
import yaml
from gymnasium import spaces
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -40,7 +40,7 @@ def test_nic(simulation):
nic: NIC = pc.network_interface[1]
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
@@ -61,17 +61,22 @@ def test_nic_categories(simulation):
"""Test the NIC observation nmne count categories."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
nic_obs = NICObservation(where=["network", "nodes", pc.hostname, "NICs", 1], include_nmne=True)
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
assert nic_obs.low_nmne_threshold == 0 # default
nic_obs = NicObservation(
@pytest.mark.skip(reason="Feature not implemented yet")
def test_config_nic_categories(simulation):
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=6,
high_nmne_threshold=9,
include_nmne=True,
)
assert nic_obs.high_nmne_threshold == 9
@@ -80,18 +85,20 @@ def test_nic_categories(simulation):
with pytest.raises(Exception):
# should throw an error
NicObservation(
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=9,
med_nmne_threshold=6,
high_nmne_threshold=9,
include_nmne=True,
)
with pytest.raises(Exception):
# should throw an error
NicObservation(
NICObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=9,
high_nmne_threshold=9,
include_nmne=True,
)

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.host_observations import HostObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.sim_container import Simulation
@@ -19,15 +19,28 @@ def simulation(example_network) -> Simulation:
return sim
def test_node_observation(simulation):
"""Test a Node observation."""
def test_host_observation(simulation):
"""Test a Host observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
host_obs = HostObservation(
where=["network", "nodes", pc.hostname],
num_applications=0,
num_files=1,
num_folders=1,
num_nics=2,
num_services=1,
include_num_access=False,
include_nmne=False,
services=[],
applications=[],
folders=[],
network_interfaces=[],
)
assert node_obs.space["operating_status"] == spaces.Discrete(5)
assert host_obs.space["operating_status"] == spaces.Discrete(5)
observation_state = node_obs.observe(simulation.describe_state())
observation_state = host_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 1 # computer is on
assert observation_state.get("SERVICES") is not None
@@ -36,11 +49,11 @@ def test_node_observation(simulation):
# turn off computer
pc.power_off()
observation_state = node_obs.observe(simulation.describe_state())
observation_state = host_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 4 # shutting down
for i in range(pc.shut_down_duration + 1):
pc.apply_timestep(i)
observation_state = node_obs.observe(simulation.describe_state())
observation_state = host_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 2

View File

@@ -0,0 +1,108 @@
from pprint import pprint
from primaite.game.agent.observations.acl_observation import ACLObservation
from primaite.game.agent.observations.nic_observations import PortObservation
from primaite.game.agent.observations.router_observation import RouterObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
def test_router_observation():
"""Test adding/removing acl rules and enabling/disabling ports."""
net = Network()
router = Router(hostname="router", num_ports=5, operating_state=NodeOperatingState.ON)
ports = [PortObservation(where=["NICs", i]) for i in range(1, 6)]
acl = ACLObservation(
where=["acl", "acl"],
num_rules=7,
ip_list=["10.0.0.1", "10.0.0.2"],
wildcard_list=["0.0.0.255", "0.0.0.1"],
port_list=["HTTP", "DNS"],
protocol_list=["TCP"],
)
router_observation = RouterObservation(where=[], ports=ports, num_ports=8, acl=acl)
# Observe the state using the RouterObservation instance
observed_output = router_observation.observe(router.describe_state())
# Check that the right number of ports and acls are in the router observation
assert len(observed_output["PORTS"]) == 8
assert len(observed_output["ACL"]) == 7
# Add an ACL rule to the router
router.acl.add_rule(
action=ACLAction.DENY,
protocol=IPProtocol.TCP,
src_ip_address="10.0.0.1",
src_wildcard_mask="0.0.0.1",
dst_ip_address="10.0.0.2",
dst_wildcard_mask="0.0.0.1",
src_port=Port.HTTP,
dst_port=Port.HTTP,
position=5,
)
# Observe the state using the RouterObservation instance
observed_output = router_observation.observe(router.describe_state())
observed_rule = observed_output["ACL"][5]
assert observed_rule["position"] == 4
assert observed_rule["permission"] == 2
assert observed_rule["source_ip_id"] == 2
assert observed_rule["source_wildcard_id"] == 3
assert observed_rule["source_port_id"] == 2
assert observed_rule["dest_ip_id"] == 3
assert observed_rule["dest_wildcard_id"] == 3
assert observed_rule["dest_port_id"] == 2
assert observed_rule["protocol_id"] == 2
# Add an ACL rule with ALL/NONE values and check that the observation is correct
router.acl.add_rule(
action=ACLAction.PERMIT,
protocol=None,
src_ip_address=None,
src_wildcard_mask=None,
dst_ip_address=None,
dst_wildcard_mask=None,
src_port=None,
dst_port=None,
position=2,
)
observed_output = router_observation.observe(router.describe_state())
observed_rule = observed_output["ACL"][2]
assert observed_rule["position"] == 1
assert observed_rule["permission"] == 1
assert observed_rule["source_ip_id"] == 1
assert observed_rule["source_wildcard_id"] == 1
assert observed_rule["source_port_id"] == 1
assert observed_rule["dest_ip_id"] == 1
assert observed_rule["dest_wildcard_id"] == 1
assert observed_rule["dest_port_id"] == 1
assert observed_rule["protocol_id"] == 1
# Check that the router ports are all disabled
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
# connect a switch to the router and check that only the correct port is updated
switch = Switch(hostname="switch", num_ports=1, operating_state=NodeOperatingState.ON)
link = net.connect(router.network_interface[1], switch.network_interface[1])
assert router.network_interface[1].enabled
observed_output = router_observation.observe(router.describe_state())
assert observed_output["PORTS"][1]["operating_status"] == 1
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(2, 6))
# disable the port and check that the operating status is updated
router.network_interface[1].disable()
assert not router.network_interface[1].enabled
observed_output = router_observation.observe(router.describe_state())
assert all(observed_output["PORTS"][i]["operating_status"] == 2 for i in range(1, 6))
# Check that ports that are out of range are shown as unused
observed_output = router_observation.observe(router.describe_state())
assert observed_output["PORTS"][6]["operating_status"] == 0
assert observed_output["PORTS"][7]["operating_status"] == 0
assert observed_output["PORTS"][8]["operating_status"] == 0

View File

@@ -14,7 +14,13 @@ def test_file_observation():
state = sim.describe_state()
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"],
include_num_access=False,
)
assert dog_file_obs.observe(state) == {"health_status": 1}
assert dog_file_obs.space == spaces.Dict({"health_status": spaces.Discrete(6)})
# TODO:
# def test_file_num_access():
# ...

View File

@@ -1,4 +1,4 @@
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.nic_observations import NICObservation
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.sim_container import Simulation
@@ -141,9 +141,9 @@ def test_describe_state_nmne(uc2_network):
def test_capture_nmne_observations(uc2_network):
"""
Tests the NicObservation class's functionality within a simulated network environment.
Tests the NICObservation class's functionality within a simulated network environment.
This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the
This test ensures the observation space, as defined by instances of NICObservation, accurately reflects the
number of MNEs detected based on network activities over multiple iterations.
The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update
@@ -168,8 +168,8 @@ def test_capture_nmne_observations(uc2_network):
set_nmne_config(nmne_config)
# Define observations for the NICs of the database and web servers
db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1])
web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1])
db_server_nic_obs = NICObservation(where=["network", "nodes", "database_server", "NICs", 1], include_nmne=True)
web_server_nic_obs = NICObservation(where=["network", "nodes", "web_server", "NICs", 1], include_nmne=True)
# Iterate through a set of test cases to simulate multiple DELETE queries
for i in range(0, 20):

View File

@@ -1,6 +1,5 @@
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.observations.observations import ICSObservation
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
@@ -52,7 +51,7 @@ def test_probabilistic_agent():
2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}},
},
)
observation_space = ObservationManager(ICSObservation())
observation_space = ObservationManager(NestedObservation(components={}))
reward_function = RewardFunction()
pa = ProbabilisticAgent(