Merged PR 301: #2350: Confirm action / observation space conforms to CAOS v0.7

## Summary
### **work related to v0.7 CAOS**
- Split observations.py into:
  - agent_observations.py
  - file_system_observations.py
  - node_observations.py
  - nic_observations.py
  - observation_manager.py
  - observations.py
  - software_observations.py
- added tests to ensure that the observations align with [QTSL-820-2450 - ARCD Track 2 Common Action Observation Space Definition v0.7](https://nscuk.sharepoint.com//r/sites/SSE32ARCDIDT/Shared%20Documents/General/ARCD/Architecture%20%26%20Design%20Documentation/Common/CAOS%20Related%20Documents/QTSL-820-2450%20-%20ARCD%20Track%202%20Common%20Action%20Observation%20Space%20Definition%20v0.7.xlsx?d=wee5713d8640b4b5bb3cb5624936e417e&csf=1&web=1&e=lByVQ5)

### preparation for v0.8 CAOS
WILL NOT AFFECT OBSERVATION SPACE FOR V0.7

**DO NOT PANIC**

these features are needed for v0.8

- integrated `num_access` to file (not used yet in file observations)
- integrated `num_file_deletions` and `num_file_creations` to file_system (not used yet in node observations)

## Test process
*How have you tested this (if applicable)?*

## Checklist
- [X] PR is linked to a **work item**
- [X] **acceptance criteria** of linked ticket are met
- [X] performed **self-review** of the code
- [X] written **tests** for any new functionality added with this PR
- [ ] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [ ] updated the **change log**
- [X] ran **pre-commit** checks for code style
- [ ] attended to any **TO-DOs** left in the code

Related work items: #2350
This commit is contained in:
Czar Echavez
2024-03-13 09:20:23 +00:00
39 changed files with 1921 additions and 1125 deletions

View File

@@ -23,6 +23,11 @@ This section defines high-level settings that apply across the game, currently i
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
``max_episode_length``
----------------------
@@ -44,3 +49,8 @@ See :ref:`List of Ports <List of Ports>` for a list of ports.
A list of protocols that the Reinforcement Learning agent(s) are able to see in the observation space.
See :ref:`List of IPProtocols <List of IPProtocols>` for a list of protocols.
``thresholds``
--------------
These are used to determine the thresholds of high, medium and low categories for counted observation occurrences.

View File

@@ -22,14 +22,17 @@ io_settings:
game:
max_episode_length: 256
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- ref: client_2_green_user

View File

@@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
if TYPE_CHECKING:
@@ -146,23 +146,10 @@ class AbstractAgent(ABC):
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""
pass
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Sample the action space randomly.
:param obs: Current observation for this agent, not used in RandomAgent
:type obs: ObsType
:param timestep: The current simulation timestep, not used in RandomAgent
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.action_manager.space.sample())
"""Return an action to be taken in the environment."""
return super().get_action(obs=obs, timestep=timestep)
class ProxyAgent(AbstractAgent):

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,188 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.game.agent.observations.observations import (
AbstractObservation,
AclObservation,
ICSObservation,
LinkObservation,
NullObservation,
)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class UC2BlueObservation(AbstractObservation):
"""Container for all observations used by the blue agent in UC2.
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
for the purpose of compiling several observation components.
"""
def __init__(
self,
nodes: List[NodeObservation],
links: List[LinkObservation],
acl: AclObservation,
ics: ICSObservation,
where: Optional[List[str]] = None,
) -> None:
"""Initialise UC2 blue observation.
:param nodes: List of node observations
:type nodes: List[NodeObservation]
:param links: List of link observations
:type links: List[LinkObservation]
:param acl: The Access Control List observation
:type acl: AclObservation
:param ics: The ICS observation
:type ics: ICSObservation
:param where: Where in the simulation state dict to find information. Not used in this particular observation
because it only compiles other observations and doesn't contribute any new information, defaults to None
:type where: Optional[List[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.nodes: List[NodeObservation] = nodes
self.links: List[LinkObservation] = links
self.acl: AclObservation = acl
self.ics: ICSObservation = ics
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
"ACL": self.acl.default_observation,
"ICS": self.ics.default_observation,
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
obs["ACL"] = self.acl.observe(state)
obs["ICS"] = self.ics.observe(state)
return obs
@property
def space(self) -> spaces.Space:
"""
Gymnasium space object describing the observation space shape.
:return: Space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
"ACL": self.acl.space,
"ICS": self.ics.space,
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
"""Create UC2 blue observation from a config.
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
links, ACL and ICS observations.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed UC2 blue observation
:rtype: UC2BlueObservation
"""
node_configs = config["nodes"]
num_services_per_node = config["num_services_per_node"]
num_folders_per_node = config["num_folders_per_node"]
num_files_per_folder = config["num_files_per_folder"]
num_nics_per_node = config["num_nics_per_node"]
nodes = [
NodeObservation.from_config(
config=n,
game=game,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)
for n in node_configs
]
link_configs = config["links"]
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
acl_config = config["acl"]
acl = AclObservation.from_config(config=acl_config, game=game)
ics_config = config["ics"]
ics = ICSObservation.from_config(config=ics_config, game=game)
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
return new
class UC2RedObservation(AbstractObservation):
"""Container for all observations used by the red agent in UC2."""
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
super().__init__()
self.where: Optional[List[str]] = where
self.nodes: List[NodeObservation] = nodes
self.default_observation: Dict = {
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
if self.where is None:
return self.default_observation
obs = {}
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
"""
Create UC2 red observation from a config.
:param config: Dictionary containing the configuration for this UC2 red observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
node_configs = config["nodes"]
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
return cls(nodes=nodes, where=["network"])
class UC2GreenObservation(NullObservation):
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
pass

View File

@@ -0,0 +1,177 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class FileObservation(AbstractObservation):
"""Observation of a file on a node in the network."""
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""
Initialise file observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.default_observation: spaces.Space = {"health_status": 0}
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
file_state = access_from_nested_dict(state, self.where)
if file_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {"health_status": file_state["visible_status"]}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"health_status": spaces.Discrete(6)})
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
"""Create file observation from a config.
:param config: Dictionary containing the configuration for this file observation.
:type config: Dict
:param game: _description_
:type game: PrimaiteGame
:param parent_where: _description_, defaults to None
:type parent_where: _type_, optional
:return: _description_
:rtype: _type_
"""
return cls(where=parent_where + ["files", config["file_name"]])
class FolderObservation(AbstractObservation):
"""Folder observation, including files inside of the folder."""
def __init__(
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
) -> None:
"""Initialise folder Observation, including files inside the folder.
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
A typical location for a file looks like this:
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
:type where: Optional[List[str]]
:param max_files: As size of the space must remain static, define max files that can be in this folder
, defaults to 5
:type max_files: int, optional
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
that even if new files are created, the existing files will always occupy the same space in the observation
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
name, it will take the position defined in this dict. Defaults to {}
:type file_positions: Dict[int, str], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.files: List[FileObservation] = files
while len(self.files) < num_files_per_folder:
self.files.append(FileObservation())
while len(self.files) > num_files_per_folder:
truncated_file = self.files.pop()
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
_LOGGER.warning(msg)
self.default_observation = {
"health_status": 0,
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
folder_state = access_from_nested_dict(state, self.where)
if folder_state is NOT_PRESENT_IN_STATE:
return self.default_observation
health_status = folder_state["health_status"]
obs = {}
obs["health_status"] = health_status
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
"health_status": spaces.Discrete(6),
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
}
)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
) -> "FolderObservation":
"""Create folder observation from a config. Also creates child file observations.
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
folder and the files inside of it.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
parent node. A typical location for a node ``where`` can be:
['network','nodes',<node_hostname>,'file_system']
:type parent_where: Optional[List[str]]
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed folder observation
:rtype: FolderObservation
"""
where = parent_where + ["folders", config["folder_name"]]
file_configs = config["files"]
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)

View File

@@ -0,0 +1,188 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
from primaite.simulator.network.nmne import CAPTURE_NMNE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NicObservation(AbstractObservation):
"""Observation of a Network Interface Card (NIC) in the network."""
low_nmne_threshold: int = 0
"""The minimum number of malicious network events to be considered low."""
med_nmne_threshold: int = 5
"""The minimum number of malicious network events to be considered medium."""
high_nmne_threshold: int = 10
"""The minimum number of malicious network events to be considered high."""
global CAPTURE_NMNE
@property
def default_observation(self) -> Dict:
"""The default NIC observation dict."""
data = {"nic_status": 0}
if CAPTURE_NMNE:
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
return data
def __init__(
self,
where: Optional[Tuple[str]] = None,
low_nmne_threshold: Optional[int] = 0,
med_nmne_threshold: Optional[int] = 5,
high_nmne_threshold: Optional[int] = 10,
) -> None:
"""Initialise NIC observation.
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
example may look like this:
['network','nodes',<node_hostname>,'NICs',<nic_number>]
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
:type where: Optional[Tuple[str]], optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
global CAPTURE_NMNE
if CAPTURE_NMNE:
self.nmne_inbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
self.nmne_outbound_last_step: int = 0
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
us find the difference."""
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
self._validate_nmne_categories(
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)
def _validate_nmne_categories(
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
):
"""
Validates the nmne threshold config.
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
"""
if high_nmne_threshold <= med_nmne_threshold:
raise Exception(
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
f"than medium nmne count ({med_nmne_threshold})"
)
if med_nmne_threshold <= low_nmne_threshold:
raise Exception(
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
f"than low nmne count ({low_nmne_threshold})"
)
self.high_nmne_threshold = high_nmne_threshold
self.med_nmne_threshold = med_nmne_threshold
self.low_nmne_threshold = low_nmne_threshold
def _categorise_mne_count(self, nmne_count: int) -> int:
"""
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
Bins are defined as follows:
- 0: No MNEs detected (0 events).
- 1: Low number of MNEs (default 1-5 events).
- 2: Moderate number of MNEs (default 6-10 events).
- 3: High number of MNEs (default more than 10 events).
:param nmne_count: Number of MNEs detected.
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
"""
if nmne_count > self.high_nmne_threshold:
return 3
elif nmne_count > self.med_nmne_threshold:
return 2
elif nmne_count > self.low_nmne_threshold:
return 1
return 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
nic_state = access_from_nested_dict(state, self.where)
if nic_state is NOT_PRESENT_IN_STATE:
return self.default_observation
else:
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
if CAPTURE_NMNE:
obs_dict.update({"NMNE": {}})
direction_dict = nic_state["nmne"].get("direction", {})
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
inbound_count = inbound_keywords.get("*", 0)
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
outbound_count = outbound_keywords.get("*", 0)
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
self.nmne_inbound_last_step = inbound_count
self.nmne_outbound_last_step = outbound_count
return obs_dict
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
if CAPTURE_NMNE:
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
return space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
"""Create NIC observation from a config.
:param config: Dictionary containing the configuration for this NIC observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
:type parent_where: Optional[List[str]]
:return: Constructed NIC observation
:rtype: NicObservation
"""
low_nmne_threshold = None
med_nmne_threshold = None
high_nmne_threshold = None
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
threshold = game.options.thresholds["nmne"]
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
return cls(
where=parent_where + ["NICs", config["nic_num"]],
low_nmne_threshold=low_nmne_threshold,
med_nmne_threshold=med_nmne_threshold,
high_nmne_threshold=high_nmne_threshold,
)

View File

@@ -0,0 +1,200 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from primaite.game.agent.observations.file_system_observations import FolderObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.observations.software_observation import ServiceObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class NodeObservation(AbstractObservation):
"""Observation of a node in the network. Includes services, folders and NICs."""
def __init__(
self,
where: Optional[Tuple[str]] = None,
services: List[ServiceObservation] = [],
folders: List[FolderObservation] = [],
network_interfaces: List[NicObservation] = [],
logon_status: bool = False,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> None:
"""
Configurable observation for a node in the simulation.
:param where: Where in the simulation state dictionary for find relevant information for this observation.
A typical location for a node looks like this:
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
:type where: List[str], optional
:param services: Mapping between position in observation space and service name, defaults to {}
:type services: Dict[int,str], optional
:param max_services: Max number of services that can be presented in observation space for this node
, defaults to 2
:type max_services: int, optional
:param folders: Mapping between position in observation space and folder name, defaults to {}
:type folders: Dict[int,str], optional
:param max_folders: Max number of folders in this node's obs space, defaults to 2
:type max_folders: int, optional
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
:type network_interfaces: Dict[int,str], optional
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
:type max_nics: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.services: List[ServiceObservation] = services
while len(self.services) < num_services_per_node:
# add empty service observation without `where` parameter so it always returns default (blank) observation
self.services.append(ServiceObservation())
while len(self.services) > num_services_per_node:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warning(msg)
# truncate service list
self.folders: List[FolderObservation] = folders
# add empty folder observation without `where` parameter that will always return default (blank) observations
while len(self.folders) < num_folders_per_node:
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warning(msg)
self.network_interfaces: List[NicObservation] = network_interfaces
while len(self.network_interfaces) < num_nics_per_node:
self.network_interfaces.append(NicObservation())
while len(self.network_interfaces) > num_nics_per_node:
truncated_nic = self.network_interfaces.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warning(msg)
self.logon_status: bool = logon_status
self.default_observation: Dict = {
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
"operating_status": 0,
}
if self.logon_status:
self.default_observation["logon_status"] = 0
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
node_state = access_from_nested_dict(state, self.where)
if node_state is NOT_PRESENT_IN_STATE:
return self.default_observation
obs = {}
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
obs["operating_status"] = node_state["operating_state"]
obs["NICS"] = {
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
}
if self.logon_status:
obs["logon_status"] = 0
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
space_shape = {
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
"operating_status": spaces.Discrete(5),
"NICS": spaces.Dict(
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
),
}
if self.logon_status:
space_shape["logon_status"] = spaces.Discrete(3)
return spaces.Dict(space_shape)
@classmethod
def from_config(
cls,
config: Dict,
game: "PrimaiteGame",
parent_where: Optional[List[str]] = None,
num_services_per_node: int = 2,
num_folders_per_node: int = 2,
num_files_per_folder: int = 2,
num_nics_per_node: int = 2,
) -> "NodeObservation":
"""Create node observation from a config. Also creates child service, folder and NIC observations.
:param config: Dictionary containing the configuration for this node observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
network. A typical location for it would be: ['network',]
:type parent_where: Optional[List[str]]
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
observation size) , defaults to 2
:type num_services_per_node: int, optional
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
observation size) , defaults to 2
:type num_folders_per_node: int, optional
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
observation size) , defaults to 2
:type num_files_per_folder: int, optional
:return: Constructed node observation
:rtype: NodeObservation
"""
node_hostname = config["node_hostname"]
if parent_where is None:
where = ["network", "nodes", node_hostname]
else:
where = parent_where + ["nodes", node_hostname]
svc_configs = config.get("services", {})
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
folder_configs = config.get("folders", {})
folders = [
FolderObservation.from_config(
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
)
for c in folder_configs
]
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
logon_status = config.get("logon_status", False)
return cls(
where=where,
services=services,
folders=folders,
network_interfaces=network_interfaces,
logon_status=logon_status,
num_services_per_node=num_services_per_node,
num_folders_per_node=num_folders_per_node,
num_files_per_folder=num_files_per_folder,
num_nics_per_node=num_nics_per_node,
)

View File

@@ -0,0 +1,73 @@
from typing import Dict, TYPE_CHECKING
from gymnasium.core import ObsType
from primaite.game.agent.observations.agent_observations import (
UC2BlueObservation,
UC2GreenObservation,
UC2RedObservation,
)
from primaite.game.agent.observations.observations import AbstractObservation
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class ObservationManager:
"""
Manage the observations of an Agent.
The observation space has the purpose of:
1. Reading the outputted state from the PrimAITE Simulation.
2. Selecting parts of the simulation state that are requested by the simulation config
3. Formatting this information so an agent can use it to make decisions.
"""
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
# refactor.
def __init__(self, observation: AbstractObservation) -> None:
"""Initialise observation space.
:param observation: Observation object
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.current_observation: ObsType
def update(self, state: Dict) -> Dict:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
"""
self.current_observation = self.obs.observe(state)
return self.current_observation
@property
def space(self) -> None:
"""Gymnasium space object describing the observation space shape."""
return self.obs.space
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.
It should contain the key 'type' which selects which observation class to use (from a choice of:
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
The other key is 'options' which are passed to the constructor of the selected observation class.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
"""
if config["type"] == "UC2BlueObservation":
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2RedObservation":
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
elif config["type"] == "UC2GreenObservation":
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
else:
raise ValueError("Observation space type invalid")

View File

@@ -0,0 +1,309 @@
"""Manages the observation space for the agent."""
from abc import ABC, abstractmethod
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
_LOGGER = getLogger(__name__)
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class AbstractObservation(ABC):
"""Abstract class for an observation space component."""
@abstractmethod
def observe(self, state: Dict) -> Any:
"""
Return an observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Any
"""
pass
@property
@abstractmethod
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space."""
pass
@classmethod
@abstractmethod
def from_config(cls, config: Dict, game: "PrimaiteGame"):
"""Create this observation space component form a serialised format.
The `game` parameter is for a the PrimaiteGame object that spawns this component.
"""
pass
class LinkObservation(AbstractObservation):
"""Observation of a link in the network."""
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
"Default observation is what should be returned when the link doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise link observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'servics', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
link_state = access_from_nested_dict(state, self.where)
if link_state is NOT_PRESENT_IN_STATE:
return self.default_observation
bandwidth = link_state["bandwidth"]
load = link_state["current_load"]
if load == 0:
utilisation_category = 0
else:
utilisation_fraction = load / bandwidth
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
utilisation_category = int(utilisation_fraction * 9) + 1
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
"""Create link observation from a config.
:param config: Dictionary containing the configuration for this link observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Constructed link observation
:rtype: LinkObservation
"""
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
class AclObservation(AbstractObservation):
"""Observation of an Access Control List (ACL) in the network."""
# TODO: should where be optional, and we can use where=None to pad the observation space?
# definitely the current approach does not support tracking files that aren't specified by name, for example
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
# this needs adding, but not for the MVP.
def __init__(
self,
node_ip_to_id: Dict[str, int],
ports: List[int],
protocols: List[str],
where: Optional[Tuple[str]] = None,
num_rules: int = 10,
) -> None:
"""Initialise ACL observation.
:param node_ip_to_id: Mapping between IP address and ID.
:type node_ip_to_id: Dict[str, int]
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
:type ports: List[int]
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
:type protocols: list[str]
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
example may look like this:
['network','nodes',<router_hostname>,'acl','acl']
:type where: Optional[Tuple[str]], optional
:param num_rules: , defaults to 10
:type num_rules: int, optional
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
self.num_rules: int = num_rules
self.node_to_id: Dict[str, int] = node_ip_to_id
"List of node IP addresses, order in this list determines how they are converted to an ID"
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
"List of ports which are part of the game that define the ordering when converting to an ID"
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
"List of protocols which are part of the game, defines ordering when converting to an ID"
self.default_observation: Dict = {
i
+ 1: {
"position": i,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
for i in range(self.num_rules)
}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
acl_state: Dict = access_from_nested_dict(state, self.where)
if acl_state is NOT_PRESENT_IN_STATE:
return self.default_observation
# TODO: what if the ACL has more rules than num of max rules for obs space
obs = {}
acl_items = dict(acl_state.items())
i = 1 # don't show rule 0 for compatibility reasons.
while i < self.num_rules + 1:
rule_state = acl_items[i]
if rule_state is None:
obs[i] = {
"position": i - 1,
"permission": 0,
"source_node_id": 0,
"source_port": 0,
"dest_node_id": 0,
"dest_port": 0,
"protocol": 0,
}
else:
src_ip = rule_state["src_ip_address"]
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
dst_ip = rule_state["dst_ip_address"]
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
src_port = rule_state["src_port"]
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
dst_port = rule_state["dst_port"]
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
protocol = rule_state["protocol"]
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
obs[i] = {
"position": i - 1,
"permission": rule_state["action"],
"source_node_id": src_node_id,
"source_port": src_port_id,
"dest_node_id": dst_node_ip,
"dest_port": dst_port_id,
"protocol": protocol_id,
}
i += 1
return obs
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape.
:return: Gymnasium space
:rtype: spaces.Space
"""
return spaces.Dict(
{
i
+ 1: spaces.Dict(
{
"position": spaces.Discrete(self.num_rules),
"permission": spaces.Discrete(3),
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
}
)
for i in range(self.num_rules)
}
)
@classmethod
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
"""Generate ACL observation from a config.
:param config: Dictionary containing the configuration for this ACL observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:return: Observation object
:rtype: AclObservation
"""
max_acl_rules = config["options"]["max_acl_rules"]
node_ip_to_idx = {}
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
node_ref = ip_map_config["node_hostname"]
nic_num = ip_map_config["nic_num"]
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
nic_obj = node_obj.network_interface[nic_num]
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
router_hostname = config["router_hostname"]
return cls(
node_ip_to_id=node_ip_to_idx,
ports=game.options.ports,
protocols=game.options.protocols,
where=["network", "nodes", router_hostname, "acl", "acl"],
num_rules=max_acl_rules,
)
class NullObservation(AbstractObservation):
"""Null observation, returns a single 0 value for the observation space."""
def __init__(self, where: Optional[List[str]] = None):
"""Initialise null observation."""
self.default_observation: Dict = {}
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation."""
return 0
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Discrete(1)
@classmethod
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
"""
Create null observation from a config.
The parameters are ignored, they are here to match the signature of the other observation classes.
"""
return cls()
class ICSObservation(NullObservation):
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
pass

View File

@@ -0,0 +1,163 @@
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite.game.agent.observations.observations import AbstractObservation
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.game import PrimaiteGame
class ServiceObservation(AbstractObservation):
"""Observation of a service in the network."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
"Default observation is what should be returned when the service doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise service observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'services', <service_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
service_state = access_from_nested_dict(state, self.where)
if service_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {
"operating_status": service_state["operating_state"],
"health_status": service_state["health_state_visible"],
}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ServiceObservation":
"""Create service observation from a config.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", config["service_name"]])
class ApplicationObservation(AbstractObservation):
"""Observation of an application in the network."""
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
"Default observation is what should be returned when the application doesn't exist."
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
"""Initialise application observation.
:param where: Store information about where in the simulation state dictionary to find the relevant information.
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
zeroes.
A typical location for a service looks like this:
`['network','nodes',<node_hostname>,'applications', <application_name>]`
:type where: Optional[List[str]]
"""
super().__init__()
self.where: Optional[Tuple[str]] = where
def observe(self, state: Dict) -> Dict:
"""Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
:return: Observation
:rtype: Dict
"""
if self.where is None:
return self.default_observation
app_state = access_from_nested_dict(state, self.where)
if app_state is NOT_PRESENT_IN_STATE:
return self.default_observation
return {
"operating_status": app_state["operating_state"],
"health_status": app_state["health_state_visible"],
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
}
@property
def space(self) -> spaces.Space:
"""Gymnasium space object describing the observation space shape."""
return spaces.Dict(
{
"operating_status": spaces.Discrete(7),
"health_status": spaces.Discrete(6),
"num_executions": spaces.Discrete(4),
}
)
@classmethod
def from_config(
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
) -> "ApplicationObservation":
"""Create application observation from a config.
:param config: Dictionary containing the configuration for this service observation.
:type config: Dict
:param game: Reference to the PrimaiteGame object that spawned this observation.
:type game: PrimaiteGame
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
:type parent_where: Optional[List[str]], optional
:return: Constructed service observation
:rtype: ApplicationObservation
"""
return cls(where=parent_where + ["services", config["application_name"]])
@classmethod
def _categorise_num_executions(cls, num_executions: int) -> int:
"""
Categorise the number of executions of an application.
Helps classify the number of application executions into different categories.
Current categories:
- 0: Application is never executed
- 1: Application is executed a low number of times (1-5)
- 2: Application is executed often (6-10)
- 3: Application is executed a high number of times (more than 10)
:param: num_executions: Number of times the application is executed
"""
if num_executions > 10:
return 3
elif num_executions > 5:
return 2
elif num_executions > 0:
return 1
return 0

View File

@@ -7,7 +7,7 @@ from gymnasium.core import ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction

View File

@@ -0,0 +1,21 @@
from typing import Dict, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""Sample the action space randomly.
:param obs: Current observation for this agent, not used in RandomAgent
:type obs: ObsType
:param timestep: The current simulation timestep, not used in RandomAgent
:type timestep: int
:return: Action formatted in CAOS format
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.action_manager.space.sample())

View File

@@ -1,16 +1,16 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from ipaddress import IPv4Address
from typing import Dict, List, Tuple
from typing import Dict, List, Optional, Tuple
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import ProbabilisticAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
@@ -67,8 +67,13 @@ class PrimaiteGameOptions(BaseModel):
model_config = ConfigDict(extra="forbid")
max_episode_length: int = 256
"""Maximum number of episodes for the PrimAITE game."""
ports: List[str]
"""A whitelist of available ports in the simulation."""
protocols: List[str]
"""A whitelist of available protocols in the simulation."""
thresholds: Optional[Dict] = {}
"""A dict containing the thresholds used for determining what is acceptable during observations."""
class PrimaiteGame:

View File

@@ -38,6 +38,8 @@ class File(FileSystemItemABC):
"The Path if real is True."
sim_root: Optional[Path] = None
"Root path of the simulation."
num_access: int = 0
"Number of times the file was accessed in the current step."
def __init__(self, **kwargs):
"""
@@ -93,11 +95,23 @@ class File(FileSystemItemABC):
return os.path.getsize(self.sim_path)
return self.sim_size
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the file.
:param timestep: The current timestep of the simulation.
"""
super().apply_timestep(timestep=timestep)
# reset the number of accesses to 0
self.num_access = 0
def describe_state(self) -> Dict:
"""Produce a dictionary describing the current state of this object."""
state = super().describe_state()
state["size"] = self.size
state["file_type"] = self.file_type.name
state["num_access"] = self.num_access
return state
def scan(self) -> bool:
@@ -106,6 +120,7 @@ class File(FileSystemItemABC):
self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}")
return False
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
self.visible_health_status = self.health_status
@@ -162,6 +177,7 @@ class File(FileSystemItemABC):
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.GOOD
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
return True
@@ -176,6 +192,7 @@ class File(FileSystemItemABC):
if self.health_status == FileSystemItemHealthStatus.GOOD:
self.health_status = FileSystemItemHealthStatus.CORRUPT
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
return True
@@ -189,6 +206,7 @@ class File(FileSystemItemABC):
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
self.health_status = FileSystemItemHealthStatus.GOOD
self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}")
return True
@@ -199,6 +217,7 @@ class File(FileSystemItemABC):
self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}")
return False
self.num_access += 1 # file was accessed
self.deleted = True
self.sys_log.info(f"File deleted {self.folder_name}/{self.name}")
return True

View File

@@ -28,6 +28,10 @@ class FileSystem(SimComponent):
"Instance of SysLog used to create system logs."
sim_root: Path
"Root path of the simulation."
num_file_creations: int = 0
"Number of file creations in the current step."
num_file_deletions: int = 0
"Number of file deletions in the current step."
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -264,6 +268,8 @@ class FileSystem(SimComponent):
)
folder.add_file(file)
self._file_request_manager.add_request(name=file.name, request_type=RequestType(func=file._request_manager))
# increment file creation
self.num_file_creations += 1
return file
def get_file(self, folder_name: str, file_name: str, include_deleted: Optional[bool] = False) -> Optional[File]:
@@ -324,6 +330,8 @@ class FileSystem(SimComponent):
if folder:
file = folder.get_file(file_name)
if file:
# increment file creation
self.num_file_deletions += 1
folder.remove_file(file)
return True
return False
@@ -355,15 +363,14 @@ class FileSystem(SimComponent):
"""
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
if file:
src_folder = file.folder
# remove file from src
src_folder.remove_file(file)
self.delete_file(folder_name=file.folder_name, file_name=file.name)
dst_folder = self.get_folder(folder_name=dst_folder_name)
if not dst_folder:
dst_folder = self.create_folder(dst_folder_name)
# add file to dst
dst_folder.add_file(file)
self.num_file_creations += 1
if file.real:
old_sim_path = file.sim_path
file.sim_path = file.sim_root / file.path
@@ -391,6 +398,10 @@ class FileSystem(SimComponent):
folder_name=dst_folder.name,
**file.model_dump(exclude={"uuid", "folder_id", "folder_name", "sim_path"}),
)
self.num_file_creations += 1
# increment access counter
file.num_access += 1
dst_folder.add_file(file_copy, force=True)
if file.real:
@@ -408,12 +419,20 @@ class FileSystem(SimComponent):
state = super().describe_state()
state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()}
state["deleted_folders"] = {folder.name: folder.describe_state() for folder in self.deleted_folders.values()}
state["num_file_creations"] = self.num_file_creations
state["num_file_deletions"] = self.num_file_deletions
return state
def apply_timestep(self, timestep: int) -> None:
"""Apply time step to FileSystem and its child folders and files."""
super().apply_timestep(timestep=timestep)
# reset number of file creations
self.num_file_creations = 0
# reset number of file deletions
self.num_file_deletions = 0
# apply timestep to folders
for folder_id in self.folders:
self.folders[folder_id].apply_timestep(timestep=timestep)

View File

@@ -138,7 +138,8 @@ class Folder(FileSystemItemABC):
file = self.get_file_by_id(file_uuid=file_id)
file.scan()
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
self.visible_health_status = FileSystemItemHealthStatus.CORRUPT
self.health_status = FileSystemItemHealthStatus.CORRUPT
self.visible_health_status = self.health_status
def _reveal_to_red_timestep(self) -> None:
"""Apply reveal to red timestep."""

View File

@@ -59,6 +59,16 @@ class Application(IOSoftware):
)
return state
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the application.
:param timestep: The current timestep of the simulation.
"""
super().apply_timestep(timestep=timestep)
self.num_executions = 0 # reset number of executions
def _can_perform_action(self) -> bool:
"""
Checks if the application can perform actions.

View File

@@ -48,6 +48,7 @@ class DatabaseClient(Application):
def execute(self) -> bool:
"""Execution definition for db client: perform a select query."""
self.num_executions += 1 # trying to connect counts as an execution
if self.connections:
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
else:

View File

@@ -193,6 +193,8 @@ class DataManipulationBot(Application):
if not self._can_perform_action():
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
self.run()
self.num_executions += 1
return self._application_loop()
def _application_loop(self) -> bool:

View File

@@ -89,6 +89,8 @@ class WebBrowser(Application):
if not self._can_perform_action():
return False
self.num_executions += 1 # trying to connect counts as an execution
# reset latest response
self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND)

View File

@@ -10,7 +10,8 @@ from _pytest.monkeypatch import MonkeyPatch
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations import ICSObservation, ObservationManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.observations.observations import ICSObservation
from primaite.game.agent.rewards import RewardFunction
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession

View File

@@ -5,8 +5,9 @@ from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import ProxyAgent, RandomAgent
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -43,15 +44,15 @@ def test_example_config():
# green agent 1
assert "client_2_green_user" in game.agents
assert isinstance(game.agents["client_2_green_user"], RandomAgent)
assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent)
# green agent 2
assert "client_1_green_user" in game.agents
assert isinstance(game.agents["client_1_green_user"], RandomAgent)
assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent)
# red agent
assert "client_1_data_manipulation_red_bot" in game.agents
assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
assert "data_manipulation_attacker" in game.agents
assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent)
# blue agent
assert "defender" in game.agents

View File

@@ -0,0 +1,25 @@
from pathlib import Path
from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.game import PrimaiteGame
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
return PrimaiteGame.from_config(cfg)
def test_thresholds():
"""Test that the game options can be parsed correctly."""
game = load_config(data_manipulation_config_path())
assert game.options.thresholds is not None

View File

@@ -0,0 +1,66 @@
import pytest
from primaite.game.agent.observations.observations import AclObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
sim = Simulation()
# set simulation network as example network
sim.network = example_network
return sim
def test_acl_observations(simulation):
"""Test the ACL rule observations."""
router: Router = simulation.network.get_node_by_hostname("router_1")
client_1: Computer = simulation.network.get_node_by_hostname("client_1")
server: Computer = simulation.network.get_node_by_hostname("server_1")
# quick set up of ntp
client_1.software_manager.install(NTPClient)
ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient")
ntp_client.configure(server.network_interface.get(1).ip_address)
server.software_manager.install(NTPServer)
# add router acl rule
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
acl_obs = AclObservation(
where=["network", "nodes", router.hostname, "acl", "acl"],
node_ip_to_id={},
ports=["NTP", "HTTP", "POSTGRES_SERVER"],
protocols=["TCP", "UDP", "ICMP"],
)
observation_space = acl_obs.observe(simulation.describe_state())
assert observation_space.get(1) is not None
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1)
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
assert rule_obs.get("source_node_id") == 1 # applies to all source nodes
assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes
assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2
assert rule_obs.get("protocol") == 1 # 1 = No Protocol
router.acl.remove_rule(1)
observation_space = acl_obs.observe(simulation.describe_state())
assert observation_space.get(1) is not None
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
assert rule_obs.get("position") == 0
assert rule_obs.get("permission") == 0
assert rule_obs.get("source_node_id") == 0
assert rule_obs.get("dest_node_id") == 0
assert rule_obs.get("source_port") == 0
assert rule_obs.get("dest_port") == 0
assert rule_obs.get("protocol") == 0

View File

@@ -0,0 +1,70 @@
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.sim_container import Simulation
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
sim = Simulation()
# set simulation network as example network
sim.network = example_network
return sim
def test_file_observation(simulation):
"""Test the file observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
# create a file on the pc
file = pc.file_system.create_file(file_name="dog.png")
dog_file_obs = FileObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
)
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
observation_state = dog_file_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1 # good initial
file.corrupt()
observation_state = dog_file_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1 # scan file so this changes
file.scan()
file.apply_timestep(0) # apply time step
observation_state = dog_file_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 3 # corrupted
def test_folder_observation(simulation):
"""Test the folder observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
# create a file and folder on the pc
folder = pc.file_system.create_folder("test_folder")
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
root_folder_obs = FolderObservation(
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"]
)
assert root_folder_obs.space["health_status"] == spaces.Discrete(6)
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("FILES") is not None
assert observation_state.get("health_status") == 1
file.corrupt() # corrupt just the file
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1 # scan folder to change this
folder.scan()
for i in range(folder.scan_duration + 1):
folder.apply_timestep(i) # apply as many timesteps as needed for a scan
observation_state = root_folder_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too

View File

@@ -0,0 +1,73 @@
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.observations import LinkObservation
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.sim_container import Simulation
@pytest.fixture(scope="function")
def simulation() -> Simulation:
sim = Simulation()
network = Network()
# Create Computer
computer = Computer(
hostname="computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
computer.power_on()
# Create Server
server = Server(
hostname="server",
ip_address="192.168.1.3",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
start_up_duration=0,
)
server.power_on()
# Connect Computer and Server
network.connect(computer.network_interface[1], server.network_interface[1])
# Should be linked
assert next(iter(network.links.values())).is_up
assert computer.ping(server.network_interface.get(1).ip_address)
# set simulation network as example network
sim.network = network
return sim
def test_link_observation(simulation):
"""Test the link observation."""
# get a link
link: Link = next(iter(simulation.network.links.values()))
computer: Computer = simulation.network.get_node_by_hostname("computer")
server: Server = simulation.network.get_node_by_hostname("server")
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
link_obs = LinkObservation(where=["network", "links", link.uuid])
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state.get("PROTOCOLS") is not None
assert observation_state["PROTOCOLS"]["ALL"] == 0
computer.ping(server.network_interface.get(1).ip_address)
observation_state = link_obs.observe(simulation.describe_state())
assert observation_state["PROTOCOLS"]["ALL"] == 1

View File

@@ -0,0 +1,97 @@
from pathlib import Path
from typing import Union
import pytest
import yaml
from gymnasium import spaces
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.nmne import CAPTURE_NMNE
from primaite.simulator.sim_container import Simulation
from tests import TEST_ASSETS_ROOT
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
with open(config_path, "r") as f:
cfg = yaml.safe_load(f)
return PrimaiteGame.from_config(cfg)
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
sim = Simulation()
# set simulation network as example network
sim.network = example_network
return sim
def test_nic(simulation):
"""Test the NIC observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic: NIC = pc.network_interface[1]
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
observation_state = nic_obs.observe(simulation.describe_state())
assert observation_state.get("nic_status") == 1 # enabled
assert observation_state.get("NMNE") is not None
assert observation_state["NMNE"].get("inbound") == 0
assert observation_state["NMNE"].get("outbound") == 0
nic.disable()
observation_state = nic_obs.observe(simulation.describe_state())
assert observation_state.get("nic_status") == 2 # disabled
def test_nic_categories(simulation):
"""Test the NIC observation nmne count categories."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
assert nic_obs.high_nmne_threshold == 10 # default
assert nic_obs.med_nmne_threshold == 5 # default
assert nic_obs.low_nmne_threshold == 0 # default
nic_obs = NicObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=6,
high_nmne_threshold=9,
)
assert nic_obs.high_nmne_threshold == 9
assert nic_obs.med_nmne_threshold == 6
assert nic_obs.low_nmne_threshold == 3
with pytest.raises(Exception):
# should throw an error
NicObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=9,
med_nmne_threshold=6,
high_nmne_threshold=9,
)
with pytest.raises(Exception):
# should throw an error
NicObservation(
where=["network", "nodes", pc.hostname, "NICs", 1],
low_nmne_threshold=3,
med_nmne_threshold=9,
high_nmne_threshold=9,
)

View File

@@ -0,0 +1,46 @@
import copy
from uuid import uuid4
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.node_observations import NodeObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.sim_container import Simulation
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
sim = Simulation()
# set simulation network as example network
sim.network = example_network
return sim
def test_node_observation(simulation):
"""Test a Node observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
assert node_obs.space["operating_status"] == spaces.Discrete(5)
observation_state = node_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 1 # computer is on
assert observation_state.get("SERVICES") is not None
assert observation_state.get("FOLDERS") is not None
assert observation_state.get("NICS") is not None
# turn off computer
pc.power_off()
observation_state = node_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 4 # shutting down
for i in range(pc.shut_down_duration + 1):
pc.apply_timestep(i)
observation_state = node_obs.observe(simulation.describe_state())
assert observation_state.get("operating_status") == 2

View File

@@ -0,0 +1,70 @@
import pytest
from gymnasium import spaces
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
@pytest.fixture(scope="function")
def simulation(example_network) -> Simulation:
sim = Simulation()
# set simulation network as example network
sim.network = example_network
return sim
def test_service_observation(simulation):
"""Test the service observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
# install software on the computer
pc.software_manager.install(NTPServer)
ntp_server = pc.software_manager.software.get("NTPServer")
assert ntp_server
service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"])
assert service_obs.space["operating_status"] == spaces.Discrete(7)
assert service_obs.space["health_status"] == spaces.Discrete(5)
observation_state = service_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 0
assert observation_state.get("operating_status") == 1 # running
ntp_server.restart()
observation_state = service_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 0
assert observation_state.get("operating_status") == 6 # resetting
def test_application_observation(simulation):
"""Test the application observation."""
pc: Computer = simulation.network.get_node_by_hostname("client_1")
# install software on the computer
pc.software_manager.install(DatabaseClient)
web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
assert web_browser
app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"])
web_browser.close()
observation_state = app_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 0
assert observation_state.get("operating_status") == 2 # stopped
assert observation_state.get("num_executions") == 0
web_browser.run()
web_browser.scan() # scan to update health status
web_browser.get_webpage("test")
observation_state = app_obs.observe(simulation.describe_state())
assert observation_state.get("health_status") == 1
assert observation_state.get("operating_status") == 1 # running
assert observation_state.get("num_executions") == 1

View File

@@ -10,28 +10,14 @@
# 4. Check that the simulation has changed in the way that I expect.
# 5. Repeat for all actions.
from typing import Dict, Tuple
from typing import Tuple
import pytest
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
from primaite.game.agent.observations import ICSObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.network.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from primaite.simulator.system.software import SoftwareHealthState

View File

@@ -1,6 +1,6 @@
from gymnasium import spaces
from primaite.game.agent.observations import FileObservation
from primaite.game.agent.observations.file_system_observations import FileObservation
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.sim_container import Simulation

View File

@@ -1,4 +1,4 @@
from primaite.game.agent.observations import NicObservation
from primaite.game.agent.observations.nic_observations import NicObservation
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.network.nmne import set_nmne_config
from primaite.simulator.sim_container import Simulation
@@ -179,8 +179,8 @@ def test_capture_nmne_observations(uc2_network):
# Observe the current state of NMNEs from the NICs of both the database and web servers
state = sim.describe_state()
db_nic_obs = db_server_nic_obs.observe(state)["nmne"]
web_nic_obs = web_server_nic_obs.observe(state)["nmne"]
db_nic_obs = db_server_nic_obs.observe(state)["NMNE"]
web_nic_obs = web_server_nic_obs.observe(state)["NMNE"]
# Define expected NMNE values based on the iteration count
if i > 10:

View File

@@ -1,7 +1,8 @@
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ICSObservation, ObservationManager
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.observations.observations import ICSObservation
from primaite.game.agent.rewards import RewardFunction
from primaite.game.agent.scripted_agents import ProbabilisticAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
def test_probabilistic_agent():

View File

@@ -1,7 +1,9 @@
import pytest
from primaite.simulator.file_system.file import File
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.file_system.file_type import FileType
from primaite.simulator.file_system.folder import Folder
def test_create_folder_and_file(file_system):
@@ -14,8 +16,15 @@ def test_create_folder_and_file(file_system):
assert len(file_system.get_folder("test_folder").files) == 1
assert file_system.num_file_creations == 1
assert file_system.get_folder("test_folder").get_file("test_file.txt")
file_system.apply_timestep(0)
# num file creations should reset
assert file_system.num_file_creations == 0
file_system.show(full=True)
@@ -23,24 +32,37 @@ def test_create_file_no_folder(file_system):
"""Tests that creating a file without a folder creates a folder and sets that as the file's parent."""
file = file_system.create_file(file_name="test_file.txt", size=10)
assert len(file_system.folders) is 1
assert file_system.num_file_creations == 1
assert file_system.get_folder("root").get_file("test_file.txt") == file
assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT
assert file_system.get_folder("root").get_file("test_file.txt").size == 10
file_system.apply_timestep(0)
# num file creations should reset
assert file_system.num_file_creations == 0
file_system.show(full=True)
def test_delete_file(file_system):
"""Tests that a file can be deleted."""
file_system.create_file(file_name="test_file.txt")
file = file_system.create_file(file_name="test_file.txt")
assert len(file_system.folders) == 1
assert len(file_system.get_folder("root").files) == 1
file_system.delete_file(folder_name="root", file_name="test_file.txt")
assert file.num_access == 1
assert file_system.num_file_deletions == 1
assert len(file_system.folders) == 1
assert len(file_system.get_folder("root").files) == 0
assert len(file_system.get_folder("root").deleted_files) == 1
file_system.apply_timestep(0)
# num file deletions should reset
assert file_system.num_file_deletions == 0
file_system.show(full=True)
@@ -54,6 +76,7 @@ def test_delete_non_existent_file(file_system):
# deleting should not change how many files are in folder
file_system.delete_file(folder_name="root", file_name="does_not_exist!")
assert file_system.num_file_deletions == 0
# should still only be one folder
assert len(file_system.folders) == 1
@@ -96,6 +119,7 @@ def test_create_duplicate_file(file_system):
assert len(file_system.folders) is 2
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
assert file_system.num_file_creations == 1
assert len(file_system.get_folder("test_folder").files) == 1
@@ -103,6 +127,7 @@ def test_create_duplicate_file(file_system):
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
assert len(file_system.get_folder("test_folder").files) == 1
assert file_system.num_file_creations == 1
file_system.show(full=True)
@@ -136,13 +161,24 @@ def test_move_file(file_system):
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 0
assert file_system.num_file_deletions == 0
assert file_system.num_file_creations == 1
file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
assert file_system.num_file_deletions == 1
assert file_system.num_file_creations == 2
assert file.num_access == 1
assert len(file_system.get_folder("src_folder").files) == 0
assert len(file_system.get_folder("dst_folder").files) == 1
assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid
file_system.apply_timestep(0)
# num file creations and deletions should reset
assert file_system.num_file_creations == 0
assert file_system.num_file_deletions == 0
file_system.show(full=True)
@@ -152,17 +188,25 @@ def test_copy_file(file_system):
file_system.create_folder(folder_name="dst_folder")
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True)
assert file_system.num_file_creations == 1
original_uuid = file.uuid
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 0
file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
assert file_system.num_file_creations == 2
assert file.num_access == 1
assert len(file_system.get_folder("src_folder").files) == 1
assert len(file_system.get_folder("dst_folder").files) == 1
assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid
file_system.apply_timestep(0)
# num file creations should reset
assert file_system.num_file_creations == 0
file_system.show(full=True)
@@ -172,13 +216,17 @@ def test_get_file(file_system):
file1: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
file2: File = file_system.create_file(file_name="test_file2.txt", folder_name="test_folder")
folder.remove_file(file2)
file_system.delete_file("test_folder", "test_file2.txt")
# file 2 was accessed before being deleted
assert file2.num_access == 1
assert file_system.get_file_by_id(file_uuid=file1.uuid, folder_uuid=folder.uuid) is not None
assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid) is None
assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid, include_deleted=True) is not None
assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None
assert file2.num_access == 1 # cannot access deleted file
file_system.delete_folder(folder_name="test_folder")
assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None