Merged PR 301: #2350: Confirm action / observation space conforms to CAOS v0.7
## Summary ### **work related to v0.7 CAOS** - Split observations.py into: - agent_observations.py - file_system_observations.py - node_observations.py - nic_observations.py - observation_manager.py - observations.py - software_observations.py - added tests to ensure that the observations align with [QTSL-820-2450 - ARCD Track 2 Common Action Observation Space Definition v0.7](https://nscuk.sharepoint.com/❌/r/sites/SSE32ARCDIDT/Shared%20Documents/General/ARCD/Architecture%20%26%20Design%20Documentation/Common/CAOS%20Related%20Documents/QTSL-820-2450%20-%20ARCD%20Track%202%20Common%20Action%20Observation%20Space%20Definition%20v0.7.xlsx?d=wee5713d8640b4b5bb3cb5624936e417e&csf=1&web=1&e=lByVQ5) ### preparation for v0.8 CAOS WILL NOT AFFECT OBSERVATION SPACE FOR V0.7 **DO NOT PANIC** these features are needed for v0.8 - integrated `num_access` to file (not used yet in file observations) - integrated `num_file_deletions` and `num_file_creations` to file_system (not used yet in node observations) ## Test process *How have you tested this (if applicable)?* ## Checklist - [X] PR is linked to a **work item** - [X] **acceptance criteria** of linked ticket are met - [X] performed **self-review** of the code - [X] written **tests** for any new functionality added with this PR - [ ] updated the **documentation** if this PR changes or adds functionality - [ ] written/updated **design docs** if this PR implements new functionality - [ ] updated the **change log** - [X] ran **pre-commit** checks for code style - [ ] attended to any **TO-DOs** left in the code Related work items: #2350
This commit is contained in:
@@ -23,6 +23,11 @@ This section defines high-level settings that apply across the game, currently i
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
``max_episode_length``
|
||||
----------------------
|
||||
@@ -44,3 +49,8 @@ See :ref:`List of Ports <List of Ports>` for a list of ports.
|
||||
A list of protocols that the Reinforcement Learning agent(s) are able to see in the observation space.
|
||||
|
||||
See :ref:`List of IPProtocols <List of IPProtocols>` for a list of protocols.
|
||||
|
||||
``thresholds``
|
||||
--------------
|
||||
|
||||
These are used to determine the thresholds of high, medium and low categories for counted observation occurrences.
|
||||
|
||||
@@ -22,14 +22,17 @@ io_settings:
|
||||
game:
|
||||
max_episode_length: 256
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
|
||||
@@ -6,7 +6,7 @@ from gymnasium.core import ActType, ObsType
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -146,23 +146,10 @@ class AbstractAgent(ABC):
|
||||
class AbstractScriptedAgent(AbstractAgent):
|
||||
"""Base class for actors which generate their own behaviour."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Sample the action space randomly.
|
||||
|
||||
:param obs: Current observation for this agent, not used in RandomAgent
|
||||
:type obs: ObsType
|
||||
:param timestep: The current simulation timestep, not used in RandomAgent
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
"""Return an action to be taken in the environment."""
|
||||
return super().get_action(obs=obs, timestep=timestep)
|
||||
|
||||
|
||||
class ProxyAgent(AbstractAgent):
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
0
src/primaite/game/agent/observations/__init__.py
Normal file
0
src/primaite/game/agent/observations/__init__.py
Normal file
188
src/primaite/game/agent/observations/agent_observations.py
Normal file
188
src/primaite/game/agent/observations/agent_observations.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.game.agent.observations.observations import (
|
||||
AbstractObservation,
|
||||
AclObservation,
|
||||
ICSObservation,
|
||||
LinkObservation,
|
||||
NullObservation,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class UC2BlueObservation(AbstractObservation):
|
||||
"""Container for all observations used by the blue agent in UC2.
|
||||
|
||||
TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler
|
||||
for the purpose of compiling several observation components.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodes: List[NodeObservation],
|
||||
links: List[LinkObservation],
|
||||
acl: AclObservation,
|
||||
ics: ICSObservation,
|
||||
where: Optional[List[str]] = None,
|
||||
) -> None:
|
||||
"""Initialise UC2 blue observation.
|
||||
|
||||
:param nodes: List of node observations
|
||||
:type nodes: List[NodeObservation]
|
||||
:param links: List of link observations
|
||||
:type links: List[LinkObservation]
|
||||
:param acl: The Access Control List observation
|
||||
:type acl: AclObservation
|
||||
:param ics: The ICS observation
|
||||
:type ics: ICSObservation
|
||||
:param where: Where in the simulation state dict to find information. Not used in this particular observation
|
||||
because it only compiles other observations and doesn't contribute any new information, defaults to None
|
||||
:type where: Optional[List[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
self.links: List[LinkObservation] = links
|
||||
self.acl: AclObservation = acl
|
||||
self.ics: ICSObservation = ics
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
"LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)},
|
||||
"ACL": self.acl.default_observation,
|
||||
"ICS": self.ics.default_observation,
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)}
|
||||
obs["ACL"] = self.acl.observe(state)
|
||||
obs["ICS"] = self.ics.observe(state)
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""
|
||||
Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
"LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}),
|
||||
"ACL": self.acl.space,
|
||||
"ICS": self.ics.space,
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation":
|
||||
"""Create UC2 blue observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes,
|
||||
links, ACL and ICS observations.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed UC2 blue observation
|
||||
:rtype: UC2BlueObservation
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
|
||||
num_services_per_node = config["num_services_per_node"]
|
||||
num_folders_per_node = config["num_folders_per_node"]
|
||||
num_files_per_folder = config["num_files_per_folder"]
|
||||
num_nics_per_node = config["num_nics_per_node"]
|
||||
nodes = [
|
||||
NodeObservation.from_config(
|
||||
config=n,
|
||||
game=game,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
for n in node_configs
|
||||
]
|
||||
|
||||
link_configs = config["links"]
|
||||
links = [LinkObservation.from_config(config=link, game=game) for link in link_configs]
|
||||
|
||||
acl_config = config["acl"]
|
||||
acl = AclObservation.from_config(config=acl_config, game=game)
|
||||
|
||||
ics_config = config["ics"]
|
||||
ics = ICSObservation.from_config(config=ics_config, game=game)
|
||||
new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"])
|
||||
return new
|
||||
|
||||
|
||||
class UC2RedObservation(AbstractObservation):
|
||||
"""Container for all observations used by the red agent in UC2."""
|
||||
|
||||
def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None:
|
||||
super().__init__()
|
||||
self.where: Optional[List[str]] = where
|
||||
self.nodes: List[NodeObservation] = nodes
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)}
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation":
|
||||
"""
|
||||
Create UC2 red observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this UC2 red observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
node_configs = config["nodes"]
|
||||
nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs]
|
||||
return cls(nodes=nodes, where=["network"])
|
||||
|
||||
|
||||
class UC2GreenObservation(NullObservation):
|
||||
"""Green agent observation. As the green agent's actions don't depend on the observation, this is empty."""
|
||||
|
||||
pass
|
||||
177
src/primaite/game/agent/observations/file_system_observations.py
Normal file
177
src/primaite/game/agent/observations/file_system_observations.py
Normal file
@@ -0,0 +1,177 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class FileObservation(AbstractObservation):
|
||||
"""Observation of a file on a node in the network."""
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""
|
||||
Initialise file observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>,'files',<file_name>]
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.default_observation: spaces.Space = {"health_status": 0}
|
||||
"Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted."
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
file_state = access_from_nested_dict(state, self.where)
|
||||
if file_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {"health_status": file_state["visible_status"]}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"health_status": spaces.Discrete(6)})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation":
|
||||
"""Create file observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this file observation.
|
||||
:type config: Dict
|
||||
:param game: _description_
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: _description_, defaults to None
|
||||
:type parent_where: _type_, optional
|
||||
:return: _description_
|
||||
:rtype: _type_
|
||||
"""
|
||||
return cls(where=parent_where + ["files", config["file_name"]])
|
||||
|
||||
|
||||
class FolderObservation(AbstractObservation):
|
||||
"""Folder observation, including files inside of the folder."""
|
||||
|
||||
def __init__(
|
||||
self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2
|
||||
) -> None:
|
||||
"""Initialise folder Observation, including files inside the folder.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this folder.
|
||||
A typical location for a file looks like this:
|
||||
['network','nodes',<node_hostname>,'file_system', 'folders',<folder_name>]
|
||||
:type where: Optional[List[str]]
|
||||
:param max_files: As size of the space must remain static, define max files that can be in this folder
|
||||
, defaults to 5
|
||||
:type max_files: int, optional
|
||||
:param file_positions: Defines the positioning within the observation space of particular files. This ensures
|
||||
that even if new files are created, the existing files will always occupy the same space in the observation
|
||||
space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the
|
||||
observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same
|
||||
name, it will take the position defined in this dict. Defaults to {}
|
||||
:type file_positions: Dict[int, str], optional
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.files: List[FileObservation] = files
|
||||
while len(self.files) < num_files_per_folder:
|
||||
self.files.append(FileObservation())
|
||||
while len(self.files) > num_files_per_folder:
|
||||
truncated_file = self.files.pop()
|
||||
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.default_observation = {
|
||||
"health_status": 0,
|
||||
"FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)},
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
folder_state = access_from_nested_dict(state, self.where)
|
||||
if folder_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
health_status = folder_state["health_status"]
|
||||
|
||||
obs = {}
|
||||
|
||||
obs["health_status"] = health_status
|
||||
obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)}
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"health_status": spaces.Discrete(6),
|
||||
"FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2
|
||||
) -> "FolderObservation":
|
||||
"""Create folder observation from a config. Also creates child file observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this folder observation. Includes the name of the
|
||||
folder and the files inside of it.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this folder's
|
||||
parent node. A typical location for a node ``where`` can be:
|
||||
['network','nodes',<node_hostname>,'file_system']
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed folder observation
|
||||
:rtype: FolderObservation
|
||||
"""
|
||||
where = parent_where + ["folders", config["folder_name"]]
|
||||
|
||||
file_configs = config["files"]
|
||||
files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs]
|
||||
|
||||
return cls(where=where, files=files, num_files_per_folder=num_files_per_folder)
|
||||
188
src/primaite/game/agent/observations/nic_observations.py
Normal file
188
src/primaite/game/agent/observations/nic_observations.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NicObservation(AbstractObservation):
|
||||
"""Observation of a Network Interface Card (NIC) in the network."""
|
||||
|
||||
low_nmne_threshold: int = 0
|
||||
"""The minimum number of malicious network events to be considered low."""
|
||||
med_nmne_threshold: int = 5
|
||||
"""The minimum number of malicious network events to be considered medium."""
|
||||
high_nmne_threshold: int = 10
|
||||
"""The minimum number of malicious network events to be considered high."""
|
||||
|
||||
global CAPTURE_NMNE
|
||||
|
||||
@property
|
||||
def default_observation(self) -> Dict:
|
||||
"""The default NIC observation dict."""
|
||||
data = {"nic_status": 0}
|
||||
if CAPTURE_NMNE:
|
||||
data.update({"NMNE": {"inbound": 0, "outbound": 0}})
|
||||
|
||||
return data
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
low_nmne_threshold: Optional[int] = 0,
|
||||
med_nmne_threshold: Optional[int] = 5,
|
||||
high_nmne_threshold: Optional[int] = 10,
|
||||
) -> None:
|
||||
"""Initialise NIC observation.
|
||||
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<node_hostname>,'NICs',<nic_number>]
|
||||
If None, this denotes that the NIC does not exist and the observation will be populated with zeroes.
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
global CAPTURE_NMNE
|
||||
if CAPTURE_NMNE:
|
||||
self.nmne_inbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
self.nmne_outbound_last_step: int = 0
|
||||
"""NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets
|
||||
us find the difference."""
|
||||
|
||||
if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold:
|
||||
self._validate_nmne_categories(
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
|
||||
def _validate_nmne_categories(
|
||||
self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10
|
||||
):
|
||||
"""
|
||||
Validates the nmne threshold config.
|
||||
|
||||
If the configuration is valid, the thresholds will be set, otherwise, an exception is raised.
|
||||
|
||||
:param: low_nmne_threshold: The minimum number of malicious network events to be considered low
|
||||
:param: med_nmne_threshold: The minimum number of malicious network events to be considered medium
|
||||
:param: high_nmne_threshold: The minimum number of malicious network events to be considered high
|
||||
"""
|
||||
if high_nmne_threshold <= med_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater "
|
||||
f"than medium nmne count ({med_nmne_threshold})"
|
||||
)
|
||||
|
||||
if med_nmne_threshold <= low_nmne_threshold:
|
||||
raise Exception(
|
||||
f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater "
|
||||
f"than low nmne count ({low_nmne_threshold})"
|
||||
)
|
||||
|
||||
self.high_nmne_threshold = high_nmne_threshold
|
||||
self.med_nmne_threshold = med_nmne_threshold
|
||||
self.low_nmne_threshold = low_nmne_threshold
|
||||
|
||||
def _categorise_mne_count(self, nmne_count: int) -> int:
|
||||
"""
|
||||
Categorise the number of Malicious Network Events (NMNEs) into discrete bins.
|
||||
|
||||
This helps in classifying the severity or volume of MNEs into manageable levels for the agent.
|
||||
|
||||
Bins are defined as follows:
|
||||
- 0: No MNEs detected (0 events).
|
||||
- 1: Low number of MNEs (default 1-5 events).
|
||||
- 2: Moderate number of MNEs (default 6-10 events).
|
||||
- 3: High number of MNEs (default more than 10 events).
|
||||
|
||||
:param nmne_count: Number of MNEs detected.
|
||||
:return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count.
|
||||
"""
|
||||
if nmne_count > self.high_nmne_threshold:
|
||||
return 3
|
||||
elif nmne_count > self.med_nmne_threshold:
|
||||
return 2
|
||||
elif nmne_count > self.low_nmne_threshold:
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
nic_state = access_from_nested_dict(state, self.where)
|
||||
|
||||
if nic_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
else:
|
||||
obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2}
|
||||
if CAPTURE_NMNE:
|
||||
obs_dict.update({"NMNE": {}})
|
||||
direction_dict = nic_state["nmne"].get("direction", {})
|
||||
inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {})
|
||||
inbound_count = inbound_keywords.get("*", 0)
|
||||
outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {})
|
||||
outbound_count = outbound_keywords.get("*", 0)
|
||||
obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step)
|
||||
obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step)
|
||||
self.nmne_inbound_last_step = inbound_count
|
||||
self.nmne_outbound_last_step = outbound_count
|
||||
return obs_dict
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
space = spaces.Dict({"nic_status": spaces.Discrete(3)})
|
||||
|
||||
if CAPTURE_NMNE:
|
||||
space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)})
|
||||
|
||||
return space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation":
|
||||
"""Create NIC observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this NIC observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent
|
||||
node. A typical location for a node ``where`` can be: ['network','nodes',<node_hostname>]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:return: Constructed NIC observation
|
||||
:rtype: NicObservation
|
||||
"""
|
||||
low_nmne_threshold = None
|
||||
med_nmne_threshold = None
|
||||
high_nmne_threshold = None
|
||||
|
||||
if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"):
|
||||
threshold = game.options.thresholds["nmne"]
|
||||
|
||||
low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None
|
||||
med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None
|
||||
high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None
|
||||
|
||||
return cls(
|
||||
where=parent_where + ["NICs", config["nic_num"]],
|
||||
low_nmne_threshold=low_nmne_threshold,
|
||||
med_nmne_threshold=med_nmne_threshold,
|
||||
high_nmne_threshold=high_nmne_threshold,
|
||||
)
|
||||
200
src/primaite/game/agent/observations/node_observations.py
Normal file
200
src/primaite/game/agent/observations/node_observations.py
Normal file
@@ -0,0 +1,200 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.observations.file_system_observations import FolderObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.observations.software_observation import ServiceObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class NodeObservation(AbstractObservation):
|
||||
"""Observation of a node in the network. Includes services, folders and NICs."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
where: Optional[Tuple[str]] = None,
|
||||
services: List[ServiceObservation] = [],
|
||||
folders: List[FolderObservation] = [],
|
||||
network_interfaces: List[NicObservation] = [],
|
||||
logon_status: bool = False,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> None:
|
||||
"""
|
||||
Configurable observation for a node in the simulation.
|
||||
|
||||
:param where: Where in the simulation state dictionary for find relevant information for this observation.
|
||||
A typical location for a node looks like this:
|
||||
['network','nodes',<hostname>]. If empty list, a default null observation will be output, defaults to []
|
||||
:type where: List[str], optional
|
||||
:param services: Mapping between position in observation space and service name, defaults to {}
|
||||
:type services: Dict[int,str], optional
|
||||
:param max_services: Max number of services that can be presented in observation space for this node
|
||||
, defaults to 2
|
||||
:type max_services: int, optional
|
||||
:param folders: Mapping between position in observation space and folder name, defaults to {}
|
||||
:type folders: Dict[int,str], optional
|
||||
:param max_folders: Max number of folders in this node's obs space, defaults to 2
|
||||
:type max_folders: int, optional
|
||||
:param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {}
|
||||
:type network_interfaces: Dict[int,str], optional
|
||||
:param max_nics: Max number of network interfaces in this node's obs space, defaults to 5
|
||||
:type max_nics: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
self.services: List[ServiceObservation] = services
|
||||
while len(self.services) < num_services_per_node:
|
||||
# add empty service observation without `where` parameter so it always returns default (blank) observation
|
||||
self.services.append(ServiceObservation())
|
||||
while len(self.services) > num_services_per_node:
|
||||
truncated_service = self.services.pop()
|
||||
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
|
||||
_LOGGER.warning(msg)
|
||||
# truncate service list
|
||||
|
||||
self.folders: List[FolderObservation] = folders
|
||||
# add empty folder observation without `where` parameter that will always return default (blank) observations
|
||||
while len(self.folders) < num_folders_per_node:
|
||||
self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder))
|
||||
while len(self.folders) > num_folders_per_node:
|
||||
truncated_folder = self.folders.pop()
|
||||
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.network_interfaces: List[NicObservation] = network_interfaces
|
||||
while len(self.network_interfaces) < num_nics_per_node:
|
||||
self.network_interfaces.append(NicObservation())
|
||||
while len(self.network_interfaces) > num_nics_per_node:
|
||||
truncated_nic = self.network_interfaces.pop()
|
||||
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
|
||||
_LOGGER.warning(msg)
|
||||
|
||||
self.logon_status: bool = logon_status
|
||||
|
||||
self.default_observation: Dict = {
|
||||
"SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)},
|
||||
"FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)},
|
||||
"NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)},
|
||||
"operating_status": 0,
|
||||
}
|
||||
if self.logon_status:
|
||||
self.default_observation["logon_status"] = 0
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
node_state = access_from_nested_dict(state, self.where)
|
||||
if node_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
obs = {}
|
||||
obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)}
|
||||
obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)}
|
||||
obs["operating_status"] = node_state["operating_state"]
|
||||
obs["NICS"] = {
|
||||
i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces)
|
||||
}
|
||||
|
||||
if self.logon_status:
|
||||
obs["logon_status"] = 0
|
||||
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
space_shape = {
|
||||
"SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}),
|
||||
"FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}),
|
||||
"operating_status": spaces.Discrete(5),
|
||||
"NICS": spaces.Dict(
|
||||
{i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)}
|
||||
),
|
||||
}
|
||||
if self.logon_status:
|
||||
space_shape["logon_status"] = spaces.Discrete(3)
|
||||
|
||||
return spaces.Dict(space_shape)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls,
|
||||
config: Dict,
|
||||
game: "PrimaiteGame",
|
||||
parent_where: Optional[List[str]] = None,
|
||||
num_services_per_node: int = 2,
|
||||
num_folders_per_node: int = 2,
|
||||
num_files_per_folder: int = 2,
|
||||
num_nics_per_node: int = 2,
|
||||
) -> "NodeObservation":
|
||||
"""Create node observation from a config. Also creates child service, folder and NIC observations.
|
||||
|
||||
:param config: Dictionary containing the configuration for this node observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary to find the information about this node's parent
|
||||
network. A typical location for it would be: ['network',]
|
||||
:type parent_where: Optional[List[str]]
|
||||
:param num_services_per_node: How many spaces for services are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_services_per_node: int, optional
|
||||
:param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_folders_per_node: int, optional
|
||||
:param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static
|
||||
observation size) , defaults to 2
|
||||
:type num_files_per_folder: int, optional
|
||||
:return: Constructed node observation
|
||||
:rtype: NodeObservation
|
||||
"""
|
||||
node_hostname = config["node_hostname"]
|
||||
if parent_where is None:
|
||||
where = ["network", "nodes", node_hostname]
|
||||
else:
|
||||
where = parent_where + ["nodes", node_hostname]
|
||||
|
||||
svc_configs = config.get("services", {})
|
||||
services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs]
|
||||
folder_configs = config.get("folders", {})
|
||||
folders = [
|
||||
FolderObservation.from_config(
|
||||
config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder
|
||||
)
|
||||
for c in folder_configs
|
||||
]
|
||||
# create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc.
|
||||
nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}]
|
||||
network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs]
|
||||
logon_status = config.get("logon_status", False)
|
||||
return cls(
|
||||
where=where,
|
||||
services=services,
|
||||
folders=folders,
|
||||
network_interfaces=network_interfaces,
|
||||
logon_status=logon_status,
|
||||
num_services_per_node=num_services_per_node,
|
||||
num_folders_per_node=num_folders_per_node,
|
||||
num_files_per_folder=num_files_per_folder,
|
||||
num_nics_per_node=num_nics_per_node,
|
||||
)
|
||||
73
src/primaite/game/agent/observations/observation_manager.py
Normal file
73
src/primaite/game/agent/observations/observation_manager.py
Normal file
@@ -0,0 +1,73 @@
|
||||
from typing import Dict, TYPE_CHECKING
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.observations.agent_observations import (
|
||||
UC2BlueObservation,
|
||||
UC2GreenObservation,
|
||||
UC2RedObservation,
|
||||
)
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class ObservationManager:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
|
||||
The observation space has the purpose of:
|
||||
1. Reading the outputted state from the PrimAITE Simulation.
|
||||
2. Selecting parts of the simulation state that are requested by the simulation config
|
||||
3. Formatting this information so an agent can use it to make decisions.
|
||||
"""
|
||||
|
||||
# TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed
|
||||
# to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next
|
||||
# refactor.
|
||||
|
||||
def __init__(self, observation: AbstractObservation) -> None:
|
||||
"""Initialise observation space.
|
||||
|
||||
:param observation: Observation object
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
self.current_observation: ObsType
|
||||
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
"""
|
||||
self.current_observation = self.obs.observe(state)
|
||||
return self.current_observation
|
||||
|
||||
@property
|
||||
def space(self) -> None:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
It should contain the key 'type' which selects which observation class to use (from a choice of:
|
||||
UC2BlueObservation, UC2RedObservation, UC2GreenObservation)
|
||||
The other key is 'options' which are passed to the constructor of the selected observation class.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
"""
|
||||
if config["type"] == "UC2BlueObservation":
|
||||
return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2RedObservation":
|
||||
return cls(UC2RedObservation.from_config(config.get("options", {}), game=game))
|
||||
elif config["type"] == "UC2GreenObservation":
|
||||
return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game))
|
||||
else:
|
||||
raise ValueError("Observation space type invalid")
|
||||
309
src/primaite/game/agent/observations/observations.py
Normal file
309
src/primaite/game/agent/observations/observations.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""Manages the observation space for the agent."""
|
||||
from abc import ABC, abstractmethod
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class AbstractObservation(ABC):
|
||||
"""Abstract class for an observation space component."""
|
||||
|
||||
@abstractmethod
|
||||
def observe(self, state: Dict) -> Any:
|
||||
"""
|
||||
Return an observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Any
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame"):
|
||||
"""Create this observation space component form a serialised format.
|
||||
|
||||
The `game` parameter is for a the PrimaiteGame object that spawns this component.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class LinkObservation(AbstractObservation):
|
||||
"""Observation of a link in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}}
|
||||
"Default observation is what should be returned when the link doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise link observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'servics', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
link_state = access_from_nested_dict(state, self.where)
|
||||
if link_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
bandwidth = link_state["bandwidth"]
|
||||
load = link_state["current_load"]
|
||||
if load == 0:
|
||||
utilisation_category = 0
|
||||
else:
|
||||
utilisation_fraction = load / bandwidth
|
||||
# 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100%
|
||||
utilisation_category = int(utilisation_fraction * 9) + 1
|
||||
|
||||
# TODO: once the links support separte load per protocol, this needs amendment to reflect that.
|
||||
return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})})
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation":
|
||||
"""Create link observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this link observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Constructed link observation
|
||||
:rtype: LinkObservation
|
||||
"""
|
||||
return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]])
|
||||
|
||||
|
||||
class AclObservation(AbstractObservation):
|
||||
"""Observation of an Access Control List (ACL) in the network."""
|
||||
|
||||
# TODO: should where be optional, and we can use where=None to pad the observation space?
|
||||
# definitely the current approach does not support tracking files that aren't specified by name, for example
|
||||
# if a file is created at runtime, we have currently got no way of telling the observation space to track it.
|
||||
# this needs adding, but not for the MVP.
|
||||
def __init__(
|
||||
self,
|
||||
node_ip_to_id: Dict[str, int],
|
||||
ports: List[int],
|
||||
protocols: List[str],
|
||||
where: Optional[Tuple[str]] = None,
|
||||
num_rules: int = 10,
|
||||
) -> None:
|
||||
"""Initialise ACL observation.
|
||||
|
||||
:param node_ip_to_id: Mapping between IP address and ID.
|
||||
:type node_ip_to_id: Dict[str, int]
|
||||
:param ports: List of ports which are part of the game that define the ordering when converting to an ID
|
||||
:type ports: List[int]
|
||||
:param protocols: List of protocols which are part of the game, defines ordering when converting to an ID
|
||||
:type protocols: list[str]
|
||||
:param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical
|
||||
example may look like this:
|
||||
['network','nodes',<router_hostname>,'acl','acl']
|
||||
:type where: Optional[Tuple[str]], optional
|
||||
:param num_rules: , defaults to 10
|
||||
:type num_rules: int, optional
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
self.num_rules: int = num_rules
|
||||
self.node_to_id: Dict[str, int] = node_ip_to_id
|
||||
"List of node IP addresses, order in this list determines how they are converted to an ID"
|
||||
self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)}
|
||||
"List of ports which are part of the game that define the ordering when converting to an ID"
|
||||
self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)}
|
||||
"List of protocols which are part of the game, defines ordering when converting to an ID"
|
||||
self.default_observation: Dict = {
|
||||
i
|
||||
+ 1: {
|
||||
"position": i,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
acl_state: Dict = access_from_nested_dict(state, self.where)
|
||||
if acl_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
|
||||
# TODO: what if the ACL has more rules than num of max rules for obs space
|
||||
obs = {}
|
||||
acl_items = dict(acl_state.items())
|
||||
i = 1 # don't show rule 0 for compatibility reasons.
|
||||
while i < self.num_rules + 1:
|
||||
rule_state = acl_items[i]
|
||||
if rule_state is None:
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": 0,
|
||||
"source_node_id": 0,
|
||||
"source_port": 0,
|
||||
"dest_node_id": 0,
|
||||
"dest_port": 0,
|
||||
"protocol": 0,
|
||||
}
|
||||
else:
|
||||
src_ip = rule_state["src_ip_address"]
|
||||
src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)]
|
||||
dst_ip = rule_state["dst_ip_address"]
|
||||
dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)]
|
||||
src_port = rule_state["src_port"]
|
||||
src_port_id = 1 if src_port is None else self.port_to_id[src_port]
|
||||
dst_port = rule_state["dst_port"]
|
||||
dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port]
|
||||
protocol = rule_state["protocol"]
|
||||
protocol_id = 1 if protocol is None else self.protocol_to_id[protocol]
|
||||
obs[i] = {
|
||||
"position": i - 1,
|
||||
"permission": rule_state["action"],
|
||||
"source_node_id": src_node_id,
|
||||
"source_port": src_port_id,
|
||||
"dest_node_id": dst_node_ip,
|
||||
"dest_port": dst_port_id,
|
||||
"protocol": protocol_id,
|
||||
}
|
||||
i += 1
|
||||
return obs
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape.
|
||||
|
||||
:return: Gymnasium space
|
||||
:rtype: spaces.Space
|
||||
"""
|
||||
return spaces.Dict(
|
||||
{
|
||||
i
|
||||
+ 1: spaces.Dict(
|
||||
{
|
||||
"position": spaces.Discrete(self.num_rules),
|
||||
"permission": spaces.Discrete(3),
|
||||
# adding two to lengths is to account for reserved values 0 (unused) and 1 (any)
|
||||
"source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"source_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2),
|
||||
"dest_port": spaces.Discrete(len(self.port_to_id) + 2),
|
||||
"protocol": spaces.Discrete(len(self.protocol_to_id) + 2),
|
||||
}
|
||||
)
|
||||
for i in range(self.num_rules)
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation":
|
||||
"""Generate ACL observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this ACL observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:return: Observation object
|
||||
:rtype: AclObservation
|
||||
"""
|
||||
max_acl_rules = config["options"]["max_acl_rules"]
|
||||
node_ip_to_idx = {}
|
||||
for ip_idx, ip_map_config in enumerate(config["ip_address_order"]):
|
||||
node_ref = ip_map_config["node_hostname"]
|
||||
nic_num = ip_map_config["nic_num"]
|
||||
node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]]
|
||||
nic_obj = node_obj.network_interface[nic_num]
|
||||
node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2
|
||||
|
||||
router_hostname = config["router_hostname"]
|
||||
return cls(
|
||||
node_ip_to_id=node_ip_to_idx,
|
||||
ports=game.options.ports,
|
||||
protocols=game.options.protocols,
|
||||
where=["network", "nodes", router_hostname, "acl", "acl"],
|
||||
num_rules=max_acl_rules,
|
||||
)
|
||||
|
||||
|
||||
class NullObservation(AbstractObservation):
|
||||
"""Null observation, returns a single 0 value for the observation space."""
|
||||
|
||||
def __init__(self, where: Optional[List[str]] = None):
|
||||
"""Initialise null observation."""
|
||||
self.default_observation: Dict = {}
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation."""
|
||||
return 0
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Discrete(1)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation":
|
||||
"""
|
||||
Create null observation from a config.
|
||||
|
||||
The parameters are ignored, they are here to match the signature of the other observation classes.
|
||||
"""
|
||||
return cls()
|
||||
|
||||
|
||||
class ICSObservation(NullObservation):
|
||||
"""ICS observation placeholder, currently not implemented so always returns a single 0."""
|
||||
|
||||
pass
|
||||
163
src/primaite/game/agent/observations/software_observation.py
Normal file
163
src/primaite/game/agent/observations/software_observation.py
Normal file
@@ -0,0 +1,163 @@
|
||||
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.observations import AbstractObservation
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.game import PrimaiteGame
|
||||
|
||||
|
||||
class ServiceObservation(AbstractObservation):
|
||||
"""Observation of a service in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0}
|
||||
"Default observation is what should be returned when the service doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise service observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'services', <service_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
service_state = access_from_nested_dict(state, self.where)
|
||||
if service_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": service_state["operating_state"],
|
||||
"health_status": service_state["health_state_visible"],
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)})
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ServiceObservation":
|
||||
"""Create service observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:rtype: ServiceObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["service_name"]])
|
||||
|
||||
|
||||
class ApplicationObservation(AbstractObservation):
|
||||
"""Observation of an application in the network."""
|
||||
|
||||
default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0}
|
||||
"Default observation is what should be returned when the application doesn't exist."
|
||||
|
||||
def __init__(self, where: Optional[Tuple[str]] = None) -> None:
|
||||
"""Initialise application observation.
|
||||
|
||||
:param where: Store information about where in the simulation state dictionary to find the relevant information.
|
||||
Optional. If None, this corresponds that the file does not exist and the observation will be populated with
|
||||
zeroes.
|
||||
|
||||
A typical location for a service looks like this:
|
||||
`['network','nodes',<node_hostname>,'applications', <application_name>]`
|
||||
:type where: Optional[List[str]]
|
||||
"""
|
||||
super().__init__()
|
||||
self.where: Optional[Tuple[str]] = where
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
"""Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
:return: Observation
|
||||
:rtype: Dict
|
||||
"""
|
||||
if self.where is None:
|
||||
return self.default_observation
|
||||
|
||||
app_state = access_from_nested_dict(state, self.where)
|
||||
if app_state is NOT_PRESENT_IN_STATE:
|
||||
return self.default_observation
|
||||
return {
|
||||
"operating_status": app_state["operating_state"],
|
||||
"health_status": app_state["health_state_visible"],
|
||||
"num_executions": self._categorise_num_executions(app_state["num_executions"]),
|
||||
}
|
||||
|
||||
@property
|
||||
def space(self) -> spaces.Space:
|
||||
"""Gymnasium space object describing the observation space shape."""
|
||||
return spaces.Dict(
|
||||
{
|
||||
"operating_status": spaces.Discrete(7),
|
||||
"health_status": spaces.Discrete(6),
|
||||
"num_executions": spaces.Discrete(4),
|
||||
}
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_config(
|
||||
cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None
|
||||
) -> "ApplicationObservation":
|
||||
"""Create application observation from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this service observation.
|
||||
:type config: Dict
|
||||
:param game: Reference to the PrimaiteGame object that spawned this observation.
|
||||
:type game: PrimaiteGame
|
||||
:param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional.
|
||||
:type parent_where: Optional[List[str]], optional
|
||||
:return: Constructed service observation
|
||||
:rtype: ApplicationObservation
|
||||
"""
|
||||
return cls(where=parent_where + ["services", config["application_name"]])
|
||||
|
||||
@classmethod
|
||||
def _categorise_num_executions(cls, num_executions: int) -> int:
|
||||
"""
|
||||
Categorise the number of executions of an application.
|
||||
|
||||
Helps classify the number of application executions into different categories.
|
||||
|
||||
Current categories:
|
||||
- 0: Application is never executed
|
||||
- 1: Application is executed a low number of times (1-5)
|
||||
- 2: Application is executed often (6-10)
|
||||
- 3: Application is executed a high number of times (more than 10)
|
||||
|
||||
:param: num_executions: Number of times the application is executed
|
||||
"""
|
||||
if num_executions > 10:
|
||||
return 3
|
||||
elif num_executions > 5:
|
||||
return 2
|
||||
elif num_executions > 0:
|
||||
return 1
|
||||
return 0
|
||||
0
src/primaite/game/agent/scripted_agents/__init__.py
Normal file
0
src/primaite/game/agent/scripted_agents/__init__.py
Normal file
@@ -7,7 +7,7 @@ from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
21
src/primaite/game/agent/scripted_agents/random_agent.py
Normal file
21
src/primaite/game/agent/scripted_agents/random_agent.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from typing import Dict, Tuple
|
||||
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite.game.agent.interface import AbstractScriptedAgent
|
||||
|
||||
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
|
||||
"""Sample the action space randomly.
|
||||
|
||||
:param obs: Current observation for this agent, not used in RandomAgent
|
||||
:type obs: ObsType
|
||||
:param timestep: The current simulation timestep, not used in RandomAgent
|
||||
:type timestep: int
|
||||
:return: Action formatted in CAOS format
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
@@ -1,16 +1,16 @@
|
||||
"""PrimAITE game - Encapsulates the simulation and agents."""
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, List, Tuple
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents import ProbabilisticAgent
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.simulator.network.hardware.base import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
@@ -67,8 +67,13 @@ class PrimaiteGameOptions(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
max_episode_length: int = 256
|
||||
"""Maximum number of episodes for the PrimAITE game."""
|
||||
ports: List[str]
|
||||
"""A whitelist of available ports in the simulation."""
|
||||
protocols: List[str]
|
||||
"""A whitelist of available protocols in the simulation."""
|
||||
thresholds: Optional[Dict] = {}
|
||||
"""A dict containing the thresholds used for determining what is acceptable during observations."""
|
||||
|
||||
|
||||
class PrimaiteGame:
|
||||
|
||||
@@ -38,6 +38,8 @@ class File(FileSystemItemABC):
|
||||
"The Path if real is True."
|
||||
sim_root: Optional[Path] = None
|
||||
"Root path of the simulation."
|
||||
num_access: int = 0
|
||||
"Number of times the file was accessed in the current step."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
@@ -93,11 +95,23 @@ class File(FileSystemItemABC):
|
||||
return os.path.getsize(self.sim_path)
|
||||
return self.sim_size
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the file.
|
||||
|
||||
:param timestep: The current timestep of the simulation.
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
# reset the number of accesses to 0
|
||||
self.num_access = 0
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""Produce a dictionary describing the current state of this object."""
|
||||
state = super().describe_state()
|
||||
state["size"] = self.size
|
||||
state["file_type"] = self.file_type.name
|
||||
state["num_access"] = self.num_access
|
||||
return state
|
||||
|
||||
def scan(self) -> bool:
|
||||
@@ -106,6 +120,7 @@ class File(FileSystemItemABC):
|
||||
self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}")
|
||||
return False
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
|
||||
self.visible_health_status = self.health_status
|
||||
@@ -162,6 +177,7 @@ class File(FileSystemItemABC):
|
||||
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.health_status = FileSystemItemHealthStatus.GOOD
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
@@ -176,6 +192,7 @@ class File(FileSystemItemABC):
|
||||
if self.health_status == FileSystemItemHealthStatus.GOOD:
|
||||
self.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
@@ -189,6 +206,7 @@ class File(FileSystemItemABC):
|
||||
if self.health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.health_status = FileSystemItemHealthStatus.GOOD
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}")
|
||||
return True
|
||||
@@ -199,6 +217,7 @@ class File(FileSystemItemABC):
|
||||
self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}")
|
||||
return False
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
self.deleted = True
|
||||
self.sys_log.info(f"File deleted {self.folder_name}/{self.name}")
|
||||
return True
|
||||
|
||||
@@ -28,6 +28,10 @@ class FileSystem(SimComponent):
|
||||
"Instance of SysLog used to create system logs."
|
||||
sim_root: Path
|
||||
"Root path of the simulation."
|
||||
num_file_creations: int = 0
|
||||
"Number of file creations in the current step."
|
||||
num_file_deletions: int = 0
|
||||
"Number of file deletions in the current step."
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -264,6 +268,8 @@ class FileSystem(SimComponent):
|
||||
)
|
||||
folder.add_file(file)
|
||||
self._file_request_manager.add_request(name=file.name, request_type=RequestType(func=file._request_manager))
|
||||
# increment file creation
|
||||
self.num_file_creations += 1
|
||||
return file
|
||||
|
||||
def get_file(self, folder_name: str, file_name: str, include_deleted: Optional[bool] = False) -> Optional[File]:
|
||||
@@ -324,6 +330,8 @@ class FileSystem(SimComponent):
|
||||
if folder:
|
||||
file = folder.get_file(file_name)
|
||||
if file:
|
||||
# increment file creation
|
||||
self.num_file_deletions += 1
|
||||
folder.remove_file(file)
|
||||
return True
|
||||
return False
|
||||
@@ -355,15 +363,14 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
file = self.get_file(folder_name=src_folder_name, file_name=src_file_name)
|
||||
if file:
|
||||
src_folder = file.folder
|
||||
|
||||
# remove file from src
|
||||
src_folder.remove_file(file)
|
||||
self.delete_file(folder_name=file.folder_name, file_name=file.name)
|
||||
dst_folder = self.get_folder(folder_name=dst_folder_name)
|
||||
if not dst_folder:
|
||||
dst_folder = self.create_folder(dst_folder_name)
|
||||
# add file to dst
|
||||
dst_folder.add_file(file)
|
||||
self.num_file_creations += 1
|
||||
if file.real:
|
||||
old_sim_path = file.sim_path
|
||||
file.sim_path = file.sim_root / file.path
|
||||
@@ -391,6 +398,10 @@ class FileSystem(SimComponent):
|
||||
folder_name=dst_folder.name,
|
||||
**file.model_dump(exclude={"uuid", "folder_id", "folder_name", "sim_path"}),
|
||||
)
|
||||
self.num_file_creations += 1
|
||||
# increment access counter
|
||||
file.num_access += 1
|
||||
|
||||
dst_folder.add_file(file_copy, force=True)
|
||||
|
||||
if file.real:
|
||||
@@ -408,12 +419,20 @@ class FileSystem(SimComponent):
|
||||
state = super().describe_state()
|
||||
state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()}
|
||||
state["deleted_folders"] = {folder.name: folder.describe_state() for folder in self.deleted_folders.values()}
|
||||
state["num_file_creations"] = self.num_file_creations
|
||||
state["num_file_deletions"] = self.num_file_deletions
|
||||
return state
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""Apply time step to FileSystem and its child folders and files."""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
# reset number of file creations
|
||||
self.num_file_creations = 0
|
||||
|
||||
# reset number of file deletions
|
||||
self.num_file_deletions = 0
|
||||
|
||||
# apply timestep to folders
|
||||
for folder_id in self.folders:
|
||||
self.folders[folder_id].apply_timestep(timestep=timestep)
|
||||
|
||||
@@ -138,7 +138,8 @@ class Folder(FileSystemItemABC):
|
||||
file = self.get_file_by_id(file_uuid=file_id)
|
||||
file.scan()
|
||||
if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT:
|
||||
self.visible_health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
self.health_status = FileSystemItemHealthStatus.CORRUPT
|
||||
self.visible_health_status = self.health_status
|
||||
|
||||
def _reveal_to_red_timestep(self) -> None:
|
||||
"""Apply reveal to red timestep."""
|
||||
|
||||
@@ -59,6 +59,16 @@ class Application(IOSoftware):
|
||||
)
|
||||
return state
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the application.
|
||||
|
||||
:param timestep: The current timestep of the simulation.
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
|
||||
self.num_executions = 0 # reset number of executions
|
||||
|
||||
def _can_perform_action(self) -> bool:
|
||||
"""
|
||||
Checks if the application can perform actions.
|
||||
|
||||
@@ -48,6 +48,7 @@ class DatabaseClient(Application):
|
||||
|
||||
def execute(self) -> bool:
|
||||
"""Execution definition for db client: perform a select query."""
|
||||
self.num_executions += 1 # trying to connect counts as an execution
|
||||
if self.connections:
|
||||
can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1])
|
||||
else:
|
||||
|
||||
@@ -193,6 +193,8 @@ class DataManipulationBot(Application):
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.run()
|
||||
|
||||
self.num_executions += 1
|
||||
return self._application_loop()
|
||||
|
||||
def _application_loop(self) -> bool:
|
||||
|
||||
@@ -89,6 +89,8 @@ class WebBrowser(Application):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
self.num_executions += 1 # trying to connect counts as an execution
|
||||
|
||||
# reset latest response
|
||||
self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND)
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@ from _pytest.monkeypatch import MonkeyPatch
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.agent.observations import ICSObservation, ObservationManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.observations.observations import ICSObservation
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.session import PrimaiteSession
|
||||
|
||||
@@ -5,8 +5,9 @@ from typing import Union
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.interface import ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -43,15 +44,15 @@ def test_example_config():
|
||||
|
||||
# green agent 1
|
||||
assert "client_2_green_user" in game.agents
|
||||
assert isinstance(game.agents["client_2_green_user"], RandomAgent)
|
||||
assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent)
|
||||
|
||||
# green agent 2
|
||||
assert "client_1_green_user" in game.agents
|
||||
assert isinstance(game.agents["client_1_green_user"], RandomAgent)
|
||||
assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent)
|
||||
|
||||
# red agent
|
||||
assert "client_1_data_manipulation_red_bot" in game.agents
|
||||
assert isinstance(game.agents["client_1_data_manipulation_red_bot"], DataManipulationAgent)
|
||||
assert "data_manipulation_attacker" in game.agents
|
||||
assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent)
|
||||
|
||||
# blue agent
|
||||
assert "defender" in game.agents
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_thresholds():
|
||||
"""Test that the game options can be parsed correctly."""
|
||||
game = load_config(data_manipulation_config_path())
|
||||
|
||||
assert game.options.thresholds is not None
|
||||
@@ -0,0 +1,66 @@
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.observations.observations import AclObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation(example_network) -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = example_network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_acl_observations(simulation):
|
||||
"""Test the ACL rule observations."""
|
||||
router: Router = simulation.network.get_node_by_hostname("router_1")
|
||||
client_1: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
server: Computer = simulation.network.get_node_by_hostname("server_1")
|
||||
|
||||
# quick set up of ntp
|
||||
client_1.software_manager.install(NTPClient)
|
||||
ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient")
|
||||
ntp_client.configure(server.network_interface.get(1).ip_address)
|
||||
server.software_manager.install(NTPServer)
|
||||
|
||||
# add router acl rule
|
||||
router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1)
|
||||
|
||||
acl_obs = AclObservation(
|
||||
where=["network", "nodes", router.hostname, "acl", "acl"],
|
||||
node_ip_to_id={},
|
||||
ports=["NTP", "HTTP", "POSTGRES_SERVER"],
|
||||
protocols=["TCP", "UDP", "ICMP"],
|
||||
)
|
||||
|
||||
observation_space = acl_obs.observe(simulation.describe_state())
|
||||
assert observation_space.get(1) is not None
|
||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||
assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1)
|
||||
assert rule_obs.get("permission") == 1 # permit = 1 deny = 2
|
||||
assert rule_obs.get("source_node_id") == 1 # applies to all source nodes
|
||||
assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes
|
||||
assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs)
|
||||
assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2
|
||||
assert rule_obs.get("protocol") == 1 # 1 = No Protocol
|
||||
|
||||
router.acl.remove_rule(1)
|
||||
|
||||
observation_space = acl_obs.observe(simulation.describe_state())
|
||||
assert observation_space.get(1) is not None
|
||||
rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP
|
||||
assert rule_obs.get("position") == 0
|
||||
assert rule_obs.get("permission") == 0
|
||||
assert rule_obs.get("source_node_id") == 0
|
||||
assert rule_obs.get("dest_node_id") == 0
|
||||
assert rule_obs.get("source_port") == 0
|
||||
assert rule_obs.get("dest_port") == 0
|
||||
assert rule_obs.get("protocol") == 0
|
||||
@@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation(example_network) -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = example_network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_file_observation(simulation):
|
||||
"""Test the file observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
# create a file on the pc
|
||||
file = pc.file_system.create_file(file_name="dog.png")
|
||||
|
||||
dog_file_obs = FileObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"]
|
||||
)
|
||||
|
||||
assert dog_file_obs.space["health_status"] == spaces.Discrete(6)
|
||||
|
||||
observation_state = dog_file_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 1 # good initial
|
||||
|
||||
file.corrupt()
|
||||
observation_state = dog_file_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 1 # scan file so this changes
|
||||
|
||||
file.scan()
|
||||
file.apply_timestep(0) # apply time step
|
||||
observation_state = dog_file_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 3 # corrupted
|
||||
|
||||
|
||||
def test_folder_observation(simulation):
|
||||
"""Test the folder observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
# create a file and folder on the pc
|
||||
folder = pc.file_system.create_folder("test_folder")
|
||||
file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder")
|
||||
|
||||
root_folder_obs = FolderObservation(
|
||||
where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"]
|
||||
)
|
||||
|
||||
assert root_folder_obs.space["health_status"] == spaces.Discrete(6)
|
||||
|
||||
observation_state = root_folder_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("FILES") is not None
|
||||
assert observation_state.get("health_status") == 1
|
||||
|
||||
file.corrupt() # corrupt just the file
|
||||
observation_state = root_folder_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 1 # scan folder to change this
|
||||
|
||||
folder.scan()
|
||||
for i in range(folder.scan_duration + 1):
|
||||
folder.apply_timestep(i) # apply as many timesteps as needed for a scan
|
||||
|
||||
observation_state = root_folder_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too
|
||||
@@ -0,0 +1,73 @@
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.observations import LinkObservation
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.base import Link, Node
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation() -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
network = Network()
|
||||
|
||||
# Create Computer
|
||||
computer = Computer(
|
||||
hostname="computer",
|
||||
ip_address="192.168.1.2",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
computer.power_on()
|
||||
|
||||
# Create Server
|
||||
server = Server(
|
||||
hostname="server",
|
||||
ip_address="192.168.1.3",
|
||||
subnet_mask="255.255.255.0",
|
||||
default_gateway="192.168.1.1",
|
||||
start_up_duration=0,
|
||||
)
|
||||
server.power_on()
|
||||
|
||||
# Connect Computer and Server
|
||||
network.connect(computer.network_interface[1], server.network_interface[1])
|
||||
|
||||
# Should be linked
|
||||
assert next(iter(network.links.values())).is_up
|
||||
|
||||
assert computer.ping(server.network_interface.get(1).ip_address)
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_link_observation(simulation):
|
||||
"""Test the link observation."""
|
||||
# get a link
|
||||
link: Link = next(iter(simulation.network.links.values()))
|
||||
|
||||
computer: Computer = simulation.network.get_node_by_hostname("computer")
|
||||
server: Server = simulation.network.get_node_by_hostname("server")
|
||||
|
||||
simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep
|
||||
|
||||
link_obs = LinkObservation(where=["network", "links", link.uuid])
|
||||
|
||||
assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("PROTOCOLS") is not None
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 0
|
||||
|
||||
computer.ping(server.network_interface.get(1).ip_address)
|
||||
|
||||
observation_state = link_obs.observe(simulation.describe_state())
|
||||
assert observation_state["PROTOCOLS"]["ALL"] == 1
|
||||
@@ -0,0 +1,97 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
|
||||
from primaite.simulator.network.nmne import CAPTURE_NMNE
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation(example_network) -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = example_network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_nic(simulation):
|
||||
"""Test the NIC observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
nic: NIC = pc.network_interface[1]
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
|
||||
assert nic_obs.space["nic_status"] == spaces.Discrete(3)
|
||||
assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4)
|
||||
assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4)
|
||||
|
||||
observation_state = nic_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("nic_status") == 1 # enabled
|
||||
assert observation_state.get("NMNE") is not None
|
||||
assert observation_state["NMNE"].get("inbound") == 0
|
||||
assert observation_state["NMNE"].get("outbound") == 0
|
||||
|
||||
nic.disable()
|
||||
observation_state = nic_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("nic_status") == 2 # disabled
|
||||
|
||||
|
||||
def test_nic_categories(simulation):
|
||||
"""Test the NIC observation nmne count categories."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1])
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 10 # default
|
||||
assert nic_obs.med_nmne_threshold == 5 # default
|
||||
assert nic_obs.low_nmne_threshold == 0 # default
|
||||
|
||||
nic_obs = NicObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
)
|
||||
|
||||
assert nic_obs.high_nmne_threshold == 9
|
||||
assert nic_obs.med_nmne_threshold == 6
|
||||
assert nic_obs.low_nmne_threshold == 3
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=9,
|
||||
med_nmne_threshold=6,
|
||||
high_nmne_threshold=9,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
# should throw an error
|
||||
NicObservation(
|
||||
where=["network", "nodes", pc.hostname, "NICs", 1],
|
||||
low_nmne_threshold=3,
|
||||
med_nmne_threshold=9,
|
||||
high_nmne_threshold=9,
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
import copy
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.node_observations import NodeObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation(example_network) -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = example_network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_node_observation(simulation):
|
||||
"""Test a Node observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
node_obs = NodeObservation(where=["network", "nodes", pc.hostname])
|
||||
|
||||
assert node_obs.space["operating_status"] == spaces.Discrete(5)
|
||||
|
||||
observation_state = node_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("operating_status") == 1 # computer is on
|
||||
|
||||
assert observation_state.get("SERVICES") is not None
|
||||
assert observation_state.get("FOLDERS") is not None
|
||||
assert observation_state.get("NICS") is not None
|
||||
|
||||
# turn off computer
|
||||
pc.power_off()
|
||||
observation_state = node_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("operating_status") == 4 # shutting down
|
||||
|
||||
for i in range(pc.shut_down_duration + 1):
|
||||
pc.apply_timestep(i)
|
||||
|
||||
observation_state = node_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("operating_status") == 2
|
||||
@@ -0,0 +1,70 @@
|
||||
import pytest
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def simulation(example_network) -> Simulation:
|
||||
sim = Simulation()
|
||||
|
||||
# set simulation network as example network
|
||||
sim.network = example_network
|
||||
|
||||
return sim
|
||||
|
||||
|
||||
def test_service_observation(simulation):
|
||||
"""Test the service observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
# install software on the computer
|
||||
pc.software_manager.install(NTPServer)
|
||||
|
||||
ntp_server = pc.software_manager.software.get("NTPServer")
|
||||
assert ntp_server
|
||||
|
||||
service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"])
|
||||
|
||||
assert service_obs.space["operating_status"] == spaces.Discrete(7)
|
||||
assert service_obs.space["health_status"] == spaces.Discrete(5)
|
||||
|
||||
observation_state = service_obs.observe(simulation.describe_state())
|
||||
|
||||
assert observation_state.get("health_status") == 0
|
||||
assert observation_state.get("operating_status") == 1 # running
|
||||
|
||||
ntp_server.restart()
|
||||
observation_state = service_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 0
|
||||
assert observation_state.get("operating_status") == 6 # resetting
|
||||
|
||||
|
||||
def test_application_observation(simulation):
|
||||
"""Test the application observation."""
|
||||
pc: Computer = simulation.network.get_node_by_hostname("client_1")
|
||||
# install software on the computer
|
||||
pc.software_manager.install(DatabaseClient)
|
||||
|
||||
web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser")
|
||||
assert web_browser
|
||||
|
||||
app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"])
|
||||
|
||||
web_browser.close()
|
||||
observation_state = app_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 0
|
||||
assert observation_state.get("operating_status") == 2 # stopped
|
||||
assert observation_state.get("num_executions") == 0
|
||||
|
||||
web_browser.run()
|
||||
web_browser.scan() # scan to update health status
|
||||
web_browser.get_webpage("test")
|
||||
observation_state = app_obs.observe(simulation.describe_state())
|
||||
assert observation_state.get("health_status") == 1
|
||||
assert observation_state.get("operating_status") == 1 # running
|
||||
assert observation_state.get("num_executions") == 1
|
||||
@@ -10,28 +10,14 @@
|
||||
# 4. Check that the simulation has changed in the way that I expect.
|
||||
# 5. Repeat for all actions.
|
||||
|
||||
from typing import Dict, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent
|
||||
from primaite.game.agent.observations import ICSObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from gymnasium import spaces
|
||||
|
||||
from primaite.game.agent.observations import FileObservation
|
||||
from primaite.game.agent.observations.file_system_observations import FileObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from primaite.game.agent.observations import NicObservation
|
||||
from primaite.game.agent.observations.nic_observations import NicObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
@@ -179,8 +179,8 @@ def test_capture_nmne_observations(uc2_network):
|
||||
|
||||
# Observe the current state of NMNEs from the NICs of both the database and web servers
|
||||
state = sim.describe_state()
|
||||
db_nic_obs = db_server_nic_obs.observe(state)["nmne"]
|
||||
web_nic_obs = web_server_nic_obs.observe(state)["nmne"]
|
||||
db_nic_obs = db_server_nic_obs.observe(state)["NMNE"]
|
||||
web_nic_obs = web_server_nic_obs.observe(state)["NMNE"]
|
||||
|
||||
# Define expected NMNE values based on the iteration count
|
||||
if i > 10:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ICSObservation, ObservationManager
|
||||
from primaite.game.agent.observations.observation_manager import ObservationManager
|
||||
from primaite.game.agent.observations.observations import ICSObservation
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.agent.scripted_agents import ProbabilisticAgent
|
||||
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
|
||||
|
||||
|
||||
def test_probabilistic_agent():
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.file_system.file_type import FileType
|
||||
from primaite.simulator.file_system.folder import Folder
|
||||
|
||||
|
||||
def test_create_folder_and_file(file_system):
|
||||
@@ -14,8 +16,15 @@ def test_create_folder_and_file(file_system):
|
||||
|
||||
assert len(file_system.get_folder("test_folder").files) == 1
|
||||
|
||||
assert file_system.num_file_creations == 1
|
||||
|
||||
assert file_system.get_folder("test_folder").get_file("test_file.txt")
|
||||
|
||||
file_system.apply_timestep(0)
|
||||
|
||||
# num file creations should reset
|
||||
assert file_system.num_file_creations == 0
|
||||
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
@@ -23,24 +32,37 @@ def test_create_file_no_folder(file_system):
|
||||
"""Tests that creating a file without a folder creates a folder and sets that as the file's parent."""
|
||||
file = file_system.create_file(file_name="test_file.txt", size=10)
|
||||
assert len(file_system.folders) is 1
|
||||
assert file_system.num_file_creations == 1
|
||||
assert file_system.get_folder("root").get_file("test_file.txt") == file
|
||||
assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT
|
||||
assert file_system.get_folder("root").get_file("test_file.txt").size == 10
|
||||
|
||||
file_system.apply_timestep(0)
|
||||
|
||||
# num file creations should reset
|
||||
assert file_system.num_file_creations == 0
|
||||
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
def test_delete_file(file_system):
|
||||
"""Tests that a file can be deleted."""
|
||||
file_system.create_file(file_name="test_file.txt")
|
||||
file = file_system.create_file(file_name="test_file.txt")
|
||||
assert len(file_system.folders) == 1
|
||||
assert len(file_system.get_folder("root").files) == 1
|
||||
|
||||
file_system.delete_file(folder_name="root", file_name="test_file.txt")
|
||||
assert file.num_access == 1
|
||||
assert file_system.num_file_deletions == 1
|
||||
assert len(file_system.folders) == 1
|
||||
assert len(file_system.get_folder("root").files) == 0
|
||||
assert len(file_system.get_folder("root").deleted_files) == 1
|
||||
|
||||
file_system.apply_timestep(0)
|
||||
|
||||
# num file deletions should reset
|
||||
assert file_system.num_file_deletions == 0
|
||||
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
@@ -54,6 +76,7 @@ def test_delete_non_existent_file(file_system):
|
||||
|
||||
# deleting should not change how many files are in folder
|
||||
file_system.delete_file(folder_name="root", file_name="does_not_exist!")
|
||||
assert file_system.num_file_deletions == 0
|
||||
|
||||
# should still only be one folder
|
||||
assert len(file_system.folders) == 1
|
||||
@@ -96,6 +119,7 @@ def test_create_duplicate_file(file_system):
|
||||
|
||||
assert len(file_system.folders) is 2
|
||||
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
assert file_system.num_file_creations == 1
|
||||
|
||||
assert len(file_system.get_folder("test_folder").files) == 1
|
||||
|
||||
@@ -103,6 +127,7 @@ def test_create_duplicate_file(file_system):
|
||||
file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
|
||||
assert len(file_system.get_folder("test_folder").files) == 1
|
||||
assert file_system.num_file_creations == 1
|
||||
|
||||
file_system.show(full=True)
|
||||
|
||||
@@ -136,13 +161,24 @@ def test_move_file(file_system):
|
||||
|
||||
assert len(file_system.get_folder("src_folder").files) == 1
|
||||
assert len(file_system.get_folder("dst_folder").files) == 0
|
||||
assert file_system.num_file_deletions == 0
|
||||
assert file_system.num_file_creations == 1
|
||||
|
||||
file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
|
||||
assert file_system.num_file_deletions == 1
|
||||
assert file_system.num_file_creations == 2
|
||||
assert file.num_access == 1
|
||||
|
||||
assert len(file_system.get_folder("src_folder").files) == 0
|
||||
assert len(file_system.get_folder("dst_folder").files) == 1
|
||||
assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid
|
||||
|
||||
file_system.apply_timestep(0)
|
||||
|
||||
# num file creations and deletions should reset
|
||||
assert file_system.num_file_creations == 0
|
||||
assert file_system.num_file_deletions == 0
|
||||
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
@@ -152,17 +188,25 @@ def test_copy_file(file_system):
|
||||
file_system.create_folder(folder_name="dst_folder")
|
||||
|
||||
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True)
|
||||
assert file_system.num_file_creations == 1
|
||||
original_uuid = file.uuid
|
||||
|
||||
assert len(file_system.get_folder("src_folder").files) == 1
|
||||
assert len(file_system.get_folder("dst_folder").files) == 0
|
||||
|
||||
file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder")
|
||||
assert file_system.num_file_creations == 2
|
||||
assert file.num_access == 1
|
||||
|
||||
assert len(file_system.get_folder("src_folder").files) == 1
|
||||
assert len(file_system.get_folder("dst_folder").files) == 1
|
||||
assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid
|
||||
|
||||
file_system.apply_timestep(0)
|
||||
|
||||
# num file creations should reset
|
||||
assert file_system.num_file_creations == 0
|
||||
|
||||
file_system.show(full=True)
|
||||
|
||||
|
||||
@@ -172,13 +216,17 @@ def test_get_file(file_system):
|
||||
file1: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
file2: File = file_system.create_file(file_name="test_file2.txt", folder_name="test_folder")
|
||||
|
||||
folder.remove_file(file2)
|
||||
file_system.delete_file("test_folder", "test_file2.txt")
|
||||
# file 2 was accessed before being deleted
|
||||
assert file2.num_access == 1
|
||||
|
||||
assert file_system.get_file_by_id(file_uuid=file1.uuid, folder_uuid=folder.uuid) is not None
|
||||
assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid) is None
|
||||
assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid, include_deleted=True) is not None
|
||||
assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None
|
||||
|
||||
assert file2.num_access == 1 # cannot access deleted file
|
||||
|
||||
file_system.delete_folder(folder_name="test_folder")
|
||||
assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user