Merge branch 'dev' into bugfix/2676_NMNE_var_access

This commit is contained in:
Nick Todd
2024-07-15 09:27:11 +01:00
81 changed files with 11491 additions and 645 deletions

View File

@@ -11,74 +11,78 @@ schedules:
branches:
include:
- 'refs/heads/dev'
pool:
vmImage: ubuntu-latest
variables:
VERSION: ''
MAJOR_VERSION: ''
steps:
- checkout: self
persistCredentials: true
jobs:
- job: PrimAITE_Benchmark
timeoutInMinutes: 360 # 6-hour maximum
pool:
vmImage: ubuntu-latest
workspace:
clean: all
steps:
- checkout: self
persistCredentials: true
- script: |
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then
DATE=$(date +%Y%m%d)
echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION
fi
displayName: 'Update VERSION file for Dev Benchmark'
- script: |
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
if [[ "$(Build.SourceBranch)" == "refs/heads/dev" ]]; then
DATE=$(date +%Y%m%d)
echo "${VERSION}+dev.${DATE}" > src/primaite/VERSION
fi
displayName: 'Update VERSION file for Dev Benchmark'
- script: |
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
MAJOR_VERSION=$(echo $VERSION | cut -d. -f1)
echo "##vso[task.setvariable variable=VERSION]$VERSION"
echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION"
displayName: 'Set Version Variables'
- script: |
VERSION=$(cat src/primaite/VERSION | tr -d '\n')
MAJOR_VERSION=$(echo $VERSION | cut -d. -f1)
echo "##vso[task.setvariable variable=VERSION]$VERSION"
echo "##vso[task.setvariable variable=MAJOR_VERSION]$MAJOR_VERSION"
displayName: 'Set Version Variables'
- task: UsePythonVersion@0
inputs:
versionSpec: '3.11'
addToPath: true
- task: UsePythonVersion@0
inputs:
versionSpec: '3.11'
addToPath: true
- script: |
python -m pip install --upgrade pip
pip install -e .[dev,rl]
primaite setup
displayName: 'Install Dependencies'
- script: |
python -m pip install --upgrade pip
pip install -e .[dev,rl]
primaite setup
displayName: 'Install Dependencies'
- script: |
cd benchmark
python3 primaite_benchmark.py
cd ..
displayName: 'Run Benchmarking Script'
- script: |
set -e
cd benchmark
python3 primaite_benchmark.py
cd ..
displayName: 'Run Benchmarking Script'
- script: |
git config --global user.email "oss@dstl.gov.uk"
git config --global user.name "Defence Science and Technology Laboratory UK"
workingDirectory: $(System.DefaultWorkingDirectory)
displayName: 'Configure Git'
condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
- script: |
git config --global user.email "oss@dstl.gov.uk"
git config --global user.name "Defence Science and Technology Laboratory UK"
workingDirectory: $(System.DefaultWorkingDirectory)
displayName: 'Configure Git'
condition: and(succeeded(), eq(variables['Build.Reason'], 'Manual'), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
- script: |
git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/*
git commit -m "Automated benchmark output commit for version $(VERSION)"
git push origin HEAD:refs/heads/$(Build.SourceBranchName)
displayName: 'Commit and Push Benchmark Results'
workingDirectory: $(System.DefaultWorkingDirectory)
env:
GIT_CREDENTIALS: $(System.AccessToken)
condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
- script: |
git add benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)/*
git commit -m "Automated benchmark output commit for version $(VERSION)"
git push origin HEAD:refs/heads/$(Build.SourceBranchName)
displayName: 'Commit and Push Benchmark Results'
workingDirectory: $(System.DefaultWorkingDirectory)
env:
GIT_CREDENTIALS: $(System.AccessToken)
condition: and(succeeded(), startsWith(variables['Build.SourceBranch'], 'refs/heads/release'))
- script: |
tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)
displayName: 'Prepare Artifacts for Publishing'
- script: |
tar czf primaite_v$(VERSION)_benchmark.tar.gz benchmark/results/v$(MAJOR_VERSION)/v$(VERSION)
displayName: 'Prepare Artifacts for Publishing'
- task: PublishPipelineArtifact@1
inputs:
targetPath: primaite_v$(VERSION)_benchmark.tar.gz
artifactName: 'benchmark-output'
publishLocation: 'pipeline'
displayName: 'Publish Benchmark Output as Artifact'
- task: PublishPipelineArtifact@1
inputs:
targetPath: primaite_v$(VERSION)_benchmark.tar.gz
artifactName: 'benchmark-output'
publishLocation: 'pipeline'
displayName: 'Publish Benchmark Output as Artifact'

View File

@@ -2,9 +2,31 @@
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Added
- **show_bandwidth_load Function**: Displays current bandwidth load for each frequency in the airspace.
- **Bandwidth Tracking**: Tracks data transmission across each frequency.
- **New Tests**: Added to validate the respect of bandwidth capacities and the correct parsing of airspace configurations from YAML files.
- **New Logging**: Added a new agent behaviour log which are more human friendly than agent history. These Logs are found in session log directory and can be enabled in the I/O settings in a yaml configuration file.
### Changed
- **NetworkInterface Speed Type**: The `speed` attribute of `NetworkInterface` has been changed from `int` to `float`.
- **Transmission Feasibility Check**: Updated `_can_transmit` function in `Link` to account for current load and total bandwidth capacity, ensuring transmissions do not exceed limits.
- **Frame Size Details**: Frame `size` attribute now includes both core size and payload size in bytes.
- **Transmission Blocking**: Enhanced `AirSpace` logic to block transmissions that would exceed the available capacity.
### Fixed
- **Transmission Permission Logic**: Corrected the logic in `can_transmit_frame` to accurately prevent overloads by checking if the transmission of a frame stays within allowable bandwidth limits after considering current load.
[//]: # (This file needs tidying up between 2.0.0 and this line as it hasn't been segmented into 3.0.0 and 3.1.0 and isn't compliant with https://keepachangelog.com/en/1.1.0/)
## 3.0.0b9
- Removed deprecated `PrimaiteSession` class.
- Added ability to set log levels via configuration.
@@ -26,8 +48,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization.
- Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training.
## [Unreleased]
- Made requests fail to reach their target if the node is off
- Added responses to requests

View File

@@ -26,7 +26,7 @@
"av_s_per_session": 3205.6340542,
"av_s_per_step": 0.10017606419375,
"av_s_per_100_steps_10_nodes": 10.017606419375,
"combined_av_reward_per_episode": {
"combined_total_reward_per_episode": {
"1": -53.42999999999999,
"2": -25.18000000000001,
"3": -42.00000000000002,

Binary file not shown.

After

Width:  |  Height:  |  Size: 322 KiB

File diff suppressed because it is too large Load Diff

View File

@@ -123,6 +123,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/environment
source/customising_scenarios
source/varying_config_files
source/action_masking
.. toctree::
:caption: Notebooks:

View File

@@ -0,0 +1,80 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
Action Masking
**************
The PrimAITE simulation is able to provide action masks in the environment output. These action masks let the agents know
about which actions are invalid based on the current environment state. For instance, it's not possible to install
software on a node that is turned off. Therefore, if an agent has a NODE_SOFTWARE_INSTALL in it's action map for that node,
the action mask will show `0` in the corresponding entry.
Configuration
=============
Action masking is supported for agents that use the `ProxyAgent` class (the class used for connecting to RL algorithms).
In order to use action masking, set the agent_settings.action_masking parameter to True in the config file.
Masking Logic
=============
The following logic is applied:
* **DONOTHING** : Always possible
* **NODE_HOST_SERVICE_SCAN** : Node is on. Service is running.
* **NODE_HOST_SERVICE_STOP** : Node is on. Service is running.
* **NODE_HOST_SERVICE_START** : Node is on. Service is stopped.
* **NODE_HOST_SERVICE_PAUSE** : Node is on. Service is running.
* **NODE_HOST_SERVICE_RESUME** : Node is on. Service is paused.
* **NODE_HOST_SERVICE_RESTART** : Node is on. Service is running.
* **NODE_HOST_SERVICE_DISABLE** : Node is on.
* **NODE_HOST_SERVICE_ENABLE** : Node is on. Service is disabled.
* **NODE_HOST_SERVICE_FIX** : Node is on. Service is running.
* **NODE_HOST_APPLICATION_EXECUTE** : Node is on.
* **NODE_HOST_APPLICATION_SCAN** : Node is on. Application is running.
* **NODE_HOST_APPLICATION_CLOSE** : Node is on. Application is running.
* **NODE_HOST_APPLICATION_FIX** : Node is on. Application is running.
* **NODE_HOST_APPLICATION_INSTALL** : Node is on.
* **NODE_HOST_APPLICATION_REMOVE** : Node is on.
* **NODE_HOST_FILE_SCAN** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_CREATE** : Node is on.
* **NODE_HOST_FILE_CHECKHASH** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_DELETE** : Node is on. File exists.
* **NODE_HOST_FILE_REPAIR** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_RESTORE** : Node is on. File exists. File is deleted.
* **NODE_HOST_FILE_CORRUPT** : Node is on. File exists. File not deleted.
* **NODE_HOST_FILE_ACCESS** : Node is on. File exists. File not deleted.
* **NODE_HOST_FOLDER_CREATE** : Node is on.
* **NODE_HOST_FOLDER_SCAN** : Node is on. Folder exists. Folder not deleted.
* **NODE_HOST_FOLDER_CHECKHASH** : Node is on. Folder exists. Folder not deleted.
* **NODE_HOST_FOLDER_REPAIR** : Node is on. Folder exists. Folder not deleted.
* **NODE_HOST_FOLDER_RESTORE** : Node is on. Folder exists. Folder is deleted.
* **NODE_HOST_OS_SCAN** : Node is on.
* **NODE_HOST_NIC_ENABLE** : NIC is disabled. Node is on.
* **NODE_HOST_NIC_DISABLE** : NIC is enabled. Node is on.
* **NODE_HOST_SHUTDOWN** : Node is on.
* **NODE_HOST_STARTUP** : Node is off.
* **NODE_HOST_RESET** : Node is on.
* **NODE_HOST_NMAP_PING_SCAN** : Node is on.
* **NODE_HOST_NMAP_PORT_SCAN** : Node is on.
* **NODE_HOST_NMAP_NETWORK_SERVICE_RECON** : Node is on.
* **NODE_ROUTER_PORT_ENABLE** : Router is on.
* **NODE_ROUTER_PORT_DISABLE** : Router is on.
* **NODE_ROUTER_ACL_ADDRULE** : Router is on.
* **NODE_ROUTER_ACL_REMOVERULE** : Router is on.
* **NODE_FIREWALL_PORT_ENABLE** : Firewall is on.
* **NODE_FIREWALL_PORT_DISABLE** : Firewall is on.
* **NODE_FIREWALL_ACL_ADDRULE** : Firewall is on.
* **NODE_FIREWALL_ACL_REMOVERULE** : Firewall is on.
Mechanism
=========
The environment iterates over the RL agent's ``action_map`` and generates the corresponding simulator request string.
It uses the ``RequestManager.check_valid()`` method to invoke the relevant ``RequestPermissionValidator`` without
actually running the request on the simulation.
Current Limitations
===================
Currently, action masking only considers whether the action as a whole is possible, it doesn't verify that the exact
parameter combination passed to the action make sense in the current context. For instance, if ACL rule 3 on router_1 is
already populated, the action for adding another rule at position 3 will be available regardless, as long as that router
is turned on. This will never block valid actions. It will just occasionally allow invalid actions.

View File

@@ -18,8 +18,11 @@ This section configures how PrimAITE saves data during simulation and training.
save_step_metadata: False
save_pcap_logs: False
save_sys_logs: False
save_agent_logs: False
write_sys_log_to_terminal: False
write_agent_log_to_terminal: False
sys_log_level: WARNING
agent_log_level: INFO
``save_logs``
@@ -57,6 +60,12 @@ Optional. Default value is ``False``.
If ``True``, then the log files which contain all node actions during the simulation will be saved.
``save_agent_logs``
-----------------
Optional. Default value is ``False``.
If ``True``, then the log files which contain all human readable agent behaviour during the simulation will be saved.
``write_sys_log_to_terminal``
-----------------------------
@@ -65,16 +74,25 @@ Optional. Default value is ``False``.
If ``True``, PrimAITE will print sys log to the terminal.
``write_agent_log_to_terminal``
-----------------------------
``sys_log_level``
-------------
Optional. Default value is ``False``.
If ``True``, PrimAITE will print all human readable agent behaviour logs to the terminal.
``sys_log_level & agent_log_level``
---------------------------------
Optional. Default value is ``WARNING``.
The level of logging that should be visible in the sys logs or the logs output to the terminal.
The level of logging that should be visible in the syslog, agent logs or the logs output to the terminal.
``save_sys_logs`` or ``write_sys_log_to_terminal`` has to be set to ``True`` for this setting to be used.
This is also true for agent behaviour logging.
Available options are:
- ``DEBUG``: Debug level items and the items below

View File

@@ -7,7 +7,7 @@
==============
In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents.
At the top level of the network are ``nodes`` and ``links``.
At the top level of the network are ``nodes``, ``links`` and ``airspace``.
e.g.
@@ -19,6 +19,9 @@ e.g.
...
links:
...
airspace:
...
``nodes``
---------
@@ -101,3 +104,27 @@ This accepts an integer value e.g. if port 1 is to be connected, the configurati
``bandwidth``
This is an integer value specifying the allowed bandwidth across the connection. Units are in Mbps.
``airspace``
------------
This section configures settings specific to the wireless network's virtual airspace.
``frequency_max_capacity_mbps``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
This setting allows the user to override the default maximum bandwidth capacity set for each frequency. The key should
be the AirSpaceFrequency name and the value be the desired maximum bandwidth capacity in mbps (megabits per second) for
a single timestep.
The below example would permit 123.45 megabits to be transmit across the WiFi 2.4 GHz frequency in a single timestep.
Setting a frequencies max capacity to 0.0 blocks that frequency on the airspace.
.. code-block:: yaml
simulation:
network:
airspace:
frequency_max_capacity_mbps:
WIFI_2_4: 123.45
WIFI_5: 0.0

View File

@@ -27,6 +27,7 @@ Contents
simulation_components/network/nodes/firewall
simulation_components/network/switch
simulation_components/network/network
simulation_components/network/airspace
simulation_components/system/internal_frame_processing
simulation_components/system/sys_log
simulation_components/system/pcap

View File

@@ -0,0 +1,42 @@
.. only:: comment
© Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
.. _airspace:
AirSpace
========
1. Introduction
---------------
The AirSpace class is the central component for wireless networks in PrimAITE and is designed to model and manage the behavior and interactions of wireless network interfaces within a simulated wireless network environment. This documentation provides a detailed overview of the AirSpace class, its components, and how they interact to create a realistic simulation of wireless network dynamics.
2. Overview of the AirSpace System
----------------------------------
The AirSpace is a virtual representation of a physical wireless environment, managing multiple wireless network interfaces that simulate devices connected to the wireless network. These interfaces communicate over radio frequencies, with their interactions influenced by various factors modeled within the AirSpace.
2.1 Key Components
^^^^^^^^^^^^^^^^^^
- **Wireless Network Interfaces**: Representations of network interfaces connected physical devices like routers, computers, or IoT devices that can send and receive data wirelessly.
- **Bandwidth Management**: Tracks data transmission over frequencies to prevent overloading and simulate real-world network congestion.
3. Managing Wireless Network Interfaces
---------------------------------------
- Interfaces can be dynamically added or removed.
- Configurations can be changed in real-time.
- The AirSpace handles data transmissions, ensuring data sent by an interface is received by all other interfaces on the same frequency.
4. AirSpace Inspection
----------------------
The AirSpace class provides methods for visualizing network behavior:
- ``show_wireless_interfaces()``: Displays current state of all interfaces
- ``show_bandwidth_load()``: Shows bandwidth utilisation

View File

@@ -37,7 +37,7 @@ additional steps to configure wireless settings:
.. code-block:: python
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.airspace import AirSpaceFrequency, ChannelWidth
# Instantiate the WirelessRouter
wireless_router = WirelessRouter(hostname="MyWirelessRouter")
@@ -49,7 +49,7 @@ additional steps to configure wireless settings:
wireless_router.configure_wireless_access_point(
port=1, ip_address="192.168.2.1",
subnet_mask="255.255.255.0",
frequency=AirSpaceFrequency.WIFI_2_4
frequency=AirSpaceFrequency.WIFI_2_4,
)
@@ -71,7 +71,7 @@ ICMP traffic, ensuring basic network connectivity and ping functionality.
.. code-block:: python
from primaite.simulator.network.airspace import AIR_SPACE, AirSpaceFrequency
from primaite.simulator.network.airspace import AirSpaceFrequency, ChannelWidth
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
@@ -130,13 +130,13 @@ ICMP traffic, ensuring basic network connectivity and ping functionality.
port=1,
ip_address="192.168.1.1",
subnet_mask="255.255.255.0",
frequency=AirSpaceFrequency.WIFI_2_4
frequency=AirSpaceFrequency.WIFI_2_4,
)
router_2.configure_wireless_access_point(
port=1,
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
frequency=AirSpaceFrequency.WIFI_2_4
frequency=AirSpaceFrequency.WIFI_2_4,
)
# Configure routes for inter-router communication

View File

@@ -55,6 +55,7 @@ rl = [
"ray[rllib] >= 2.20.0, < 3",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
"sb3-contrib==2.1.0",
]
dev = [
"build==0.10.0",

View File

@@ -741,6 +741,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true

View File

@@ -733,6 +733,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true
- ref: defender_2
team: BLUE
@@ -1316,6 +1317,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true

View File

@@ -44,3 +44,18 @@ def data_manipulation_config_path() -> Path:
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path
def data_manipulation_marl_config_path() -> Path:
"""
Get the path to the MARL example config.
:return: Path to yaml config file for the MARL scenario.
:rtype: Path
"""
path = _EXAMPLE_CFG / "data_manipulation_marl.yaml"
if not path.exists():
msg = f"Example config does not exist: {path}. Have you run `primaite setup`?"
_LOGGER.error(msg)
raise FileNotFoundError(msg)
return path

View File

@@ -49,7 +49,7 @@ class AbstractAction(ABC):
objects."""
@abstractmethod
def form_request(self) -> List[str]:
def form_request(self) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return []
@@ -67,7 +67,7 @@ class DoNothingAction(AbstractAction):
# i.e. a choice between one option. To make enumerating this action easier, we are adding a 'dummy' paramter
# with one option. This just aids the Action Manager to enumerate all possibilities.
def form_request(self, **kwargs) -> List[str]:
def form_request(self, **kwargs) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["do_nothing"]
@@ -86,7 +86,7 @@ class NodeServiceAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, service_id: int) -> List[str]:
def form_request(self, node_id: int, service_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
service_name = self.manager.get_service_name_by_idx(node_id, service_id)
@@ -181,7 +181,7 @@ class NodeApplicationAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, application_id: int) -> List[str]:
def form_request(self, node_id: int, application_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
application_name = self.manager.get_application_name_by_idx(node_id, application_id)
@@ -229,7 +229,7 @@ class NodeApplicationInstallAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str) -> List[str]:
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
@@ -324,7 +324,7 @@ class NodeApplicationRemoveAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
def form_request(self, node_id: int, application_name: str) -> List[str]:
def form_request(self, node_id: int, application_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None:
@@ -346,7 +346,7 @@ class NodeFolderAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -394,7 +394,9 @@ class NodeFileCreateAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "create"
def form_request(self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False) -> List[str]:
def form_request(
self, node_id: int, folder_name: str, file_name: str, force: Optional[bool] = False
) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None or file_name is None:
@@ -409,7 +411,7 @@ class NodeFolderCreateAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "create"
def form_request(self, node_id: int, folder_name: str) -> List[str]:
def form_request(self, node_id: int, folder_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None:
@@ -430,7 +432,7 @@ class NodeFileAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -463,7 +465,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb: str = "delete"
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
def form_request(self, node_id: int, folder_id: int, file_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
folder_name = self.manager.get_folder_name_by_idx(node_idx=node_id, folder_idx=folder_id)
@@ -504,7 +506,7 @@ class NodeFileAccessAction(AbstractAction):
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, **kwargs)
self.verb: str = "access"
def form_request(self, node_id: int, folder_name: str, file_name: str) -> List[str]:
def form_request(self, node_id: int, folder_name: str, file_name: str) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
if node_name is None or folder_name is None or file_name is None:
@@ -525,7 +527,7 @@ class NodeAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int) -> List[str]:
def form_request(self, node_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_id)
return ["network", "node", node_name, self.verb]
@@ -740,7 +742,7 @@ class RouterACLRemoveRuleAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"position": max_acl_rules}
def form_request(self, target_router: str, position: int) -> List[str]:
def form_request(self, target_router: str, position: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
return ["network", "node", target_router, "acl", "remove_rule", position]
@@ -923,7 +925,7 @@ class HostNICAbstractAction(AbstractAction):
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, nic_id: int) -> List[str]:
def form_request(self, node_id: int, nic_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_name = self.manager.get_node_name_by_idx(node_idx=node_id)
nic_num = self.manager.get_nic_num_by_idx(node_idx=node_id, nic_idx=nic_id)
@@ -960,7 +962,7 @@ class NetworkPortEnableAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
@@ -979,7 +981,7 @@ class NetworkPortDisableAction(AbstractAction):
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"port_id": max_nics_per_node}
def form_request(self, target_nodename: str, port_id: int) -> List[str]:
def form_request(self, target_nodename: str, port_id: int) -> RequestFormat:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
if target_nodename is None or port_id is None:
return ["do_nothing"]
@@ -1315,7 +1317,7 @@ class ActionManager:
act_identifier, act_options = self.action_map[action]
return act_identifier, act_options
def form_request(self, action_identifier: str, action_options: Dict) -> List[str]:
def form_request(self, action_identifier: str, action_options: Dict) -> RequestFormat:
"""Take action in CAOS format and use the execution definition to change it into PrimAITE request format."""
act_obj = self.actions[action_identifier]
return act_obj.form_request(**action_options)

View File

@@ -0,0 +1,188 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import logging
from pathlib import Path
from prettytable import MARKDOWN, PrettyTable
from primaite.simulator import LogLevel, SIM_OUTPUT
class _NotJSONFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
"""
Determines if a log message does not start and end with '{' and '}' (i.e., it is not a JSON-like message).
:param record: LogRecord object containing all the information pertinent to the event being logged.
:return: True if log message is not JSON-like, False otherwise.
"""
return not record.getMessage().startswith("{") and not record.getMessage().endswith("}")
class AgentLog:
"""
A Agent Log class is a simple logger dedicated to managing and writing logging updates and information for an agent.
Each log message is written to a file located at: <simulation output directory>/agent_name/agent_name.log
"""
def __init__(self, agent_name: str):
"""
Constructs a Agent Log instance for a given hostname.
:param hostname: The hostname associated with the system logs being recorded.
"""
self.agent_name = agent_name
self.current_episode: int = 1
self.current_timestep: int = 0
self.setup_logger()
@property
def timestep(self) -> int:
"""Returns the current timestep. Used for log indexing.
:return: The current timestep as an Int.
"""
return self.current_timestep
def update_timestep(self, new_timestep: int):
"""
Updates the self.current_timestep attribute with the given parameter.
This method is called within .step() to ensure that all instances of Agent Logs
are in sync with one another.
:param new_timestep: The new timestep.
"""
self.current_timestep = new_timestep
def setup_logger(self):
"""
Configures the logger for this Agent Log instance.
The logger is set to the DEBUG level, and is equipped with a handler that writes to a file and filters out
JSON-like messages.
"""
if not SIM_OUTPUT.save_agent_logs:
return
log_path = self._get_log_path()
file_handler = logging.FileHandler(filename=log_path)
file_handler.setLevel(logging.DEBUG)
log_format = "%(timestep)s::%(levelname)s::%(message)s"
file_handler.setFormatter(logging.Formatter(log_format))
self.logger = logging.getLogger(f"{self.agent_name}_log")
for handler in self.logger.handlers:
self.logger.removeHandler(handler)
self.logger.setLevel(logging.DEBUG)
self.logger.addHandler(file_handler)
def _get_log_path(self) -> Path:
"""
Constructs the path for the log file based on the agent name.
:return: Path object representing the location of the log file.
"""
root = SIM_OUTPUT.agent_behaviour_path / f"episode_{self.current_episode}" / self.agent_name
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.agent_name}.log"
def _write_to_terminal(self, msg: str, level: str, to_terminal: bool = False):
if to_terminal or SIM_OUTPUT.write_agent_log_to_terminal:
print(f"{self.agent_name}: ({ self.timestep}) ({level}) {msg}")
def debug(self, msg: str, to_terminal: bool = False):
"""
Logs a message with the DEBUG level.
:param msg: The message to be logged.
:param to_terminal: If True, prints to the terminal too.
"""
if SIM_OUTPUT.agent_log_level > LogLevel.DEBUG:
return
if SIM_OUTPUT.save_agent_logs:
self.logger.debug(msg, extra={"timestep": self.timestep})
self._write_to_terminal(msg, "DEBUG", to_terminal)
def info(self, msg: str, to_terminal: bool = False):
"""
Logs a message with the INFO level.
:param msg: The message to be logged.
:param timestep: The current timestep.
:param to_terminal: If True, prints to the terminal too.
"""
if SIM_OUTPUT.agent_log_level > LogLevel.INFO:
return
if SIM_OUTPUT.save_agent_logs:
self.logger.info(msg, extra={"timestep": self.timestep})
self._write_to_terminal(msg, "INFO", to_terminal)
def warning(self, msg: str, to_terminal: bool = False):
"""
Logs a message with the WARNING level.
:param msg: The message to be logged.
:param timestep: The current timestep.
:param to_terminal: If True, prints to the terminal too.
"""
if SIM_OUTPUT.agent_log_level > LogLevel.WARNING:
return
if SIM_OUTPUT.save_agent_logs:
self.logger.warning(msg, extra={"timestep": self.timestep})
self._write_to_terminal(msg, "WARNING", to_terminal)
def error(self, msg: str, to_terminal: bool = False):
"""
Logs a message with the ERROR level.
:param msg: The message to be logged.
:param timestep: The current timestep.
:param to_terminal: If True, prints to the terminal too.
"""
if SIM_OUTPUT.agent_log_level > LogLevel.ERROR:
return
if SIM_OUTPUT.save_agent_logs:
self.logger.error(msg, extra={"timestep": self.timestep})
self._write_to_terminal(msg, "ERROR", to_terminal)
def critical(self, msg: str, to_terminal: bool = False):
"""
Logs a message with the CRITICAL level.
:param msg: The message to be logged.
:param timestep: The current timestep.
:param to_terminal: If True, prints to the terminal too.
"""
if LogLevel.CRITICAL < SIM_OUTPUT.agent_log_level:
return
if SIM_OUTPUT.save_agent_logs:
self.logger.critical(msg, extra={"timestep": self.timestep})
self._write_to_terminal(msg, "CRITICAL", to_terminal)
def show(self, last_n: int = 10, markdown: bool = False):
"""
Print an Agents Log as a table.
Generate and print PrettyTable instance that shows the agents behaviour log, with columns Time step,
Level and Message.
:param markdown: Use Markdown style in table output. Defaults to False.
"""
table = PrettyTable(["Time Step", "Level", "Message"])
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = f"{self.agent_name} Behaviour Log"
if self._get_log_path().exists():
with open(self._get_log_path()) as file:
lines = file.readlines()
for line in lines[-last_n:]:
table.add_row(line.strip().split("::"))
print(table)

View File

@@ -7,6 +7,7 @@ from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.agent_log import AgentLog
from primaite.game.agent.observations.observation_manager import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.interface.request import RequestFormat, RequestResponse
@@ -69,6 +70,8 @@ class AgentSettings(BaseModel):
"Configuration for when an agent begins performing it's actions"
flatten_obs: bool = True
"Whether to flatten the observation space before passing it to the agent. True by default."
action_masking: bool = False
"Whether to return action masks at each step."
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
@@ -116,6 +119,7 @@ class AbstractAgent(ABC):
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.history: List[AgentHistoryItem] = []
self.logger = AgentLog(agent_name)
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -205,6 +209,7 @@ class ProxyAgent(AbstractAgent):
)
self.most_recent_action: ActType
self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False
self.action_masking: bool = agent_settings.action_masking if agent_settings else False
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
"""

View File

@@ -38,10 +38,11 @@ class DataManipulationAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
if timestep < self.next_execution_timestep:
self.logger.debug(msg="Performing do NOTHING")
return "DONOTHING", {}
self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency)
self.logger.info(msg="Performing a data manipulation attack!")
return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0}
def setup_agent(self) -> None:
@@ -54,3 +55,4 @@ class DataManipulationAgent(AbstractScriptedAgent):
# we are assuming that every node in the node manager has a data manipulation application at idx 0
num_nodes = len(self.action_manager.node_names)
self.starting_node_idx = random.randint(0, num_nodes - 1)
self.logger.debug(msg=f"Select Start Node ID: {self.starting_node_idx}")

View File

@@ -85,4 +85,5 @@ class ProbabilisticAgent(AbstractScriptedAgent):
:rtype: Tuple[str, Dict]
"""
choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities)
self.logger.info(f"Performing Action: {choice}")
return self.action_manager.get_action(choice)

View File

@@ -3,6 +3,7 @@
from ipaddress import IPv4Address
from typing import Dict, List, Optional
import numpy as np
from pydantic import BaseModel, ConfigDict
from primaite import DEFAULT_BANDWIDTH, getLogger
@@ -15,7 +16,9 @@ from primaite.game.agent.scripted_agents.probabilistic_agent import Probabilisti
from primaite.game.agent.scripted_agents.random_agent import PeriodicAgent
from primaite.game.agent.scripted_agents.tap001 import TAP001
from primaite.game.science import graph_has_cycle, topological_sort
from primaite.simulator.network.hardware.base import NetworkInterface, NodeOperatingState
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.network.airspace import AirSpaceFrequency
from primaite.simulator.network.hardware.base import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.host_node import NIC
from primaite.simulator.network.hardware.nodes.host.server import Printer, Server
@@ -167,6 +170,8 @@ class PrimaiteGame:
for _, agent in self.agents.items():
obs = agent.observation_manager.current_observation
action_choice, parameters = agent.get_action(obs, timestep=self.step_counter)
if SIM_OUTPUT.save_agent_logs:
agent.logger.debug(f"Chosen Action: {action_choice}")
request = agent.format_request(action_choice, parameters)
response = self.simulation.apply_request(request)
agent.process_action_response(
@@ -185,8 +190,14 @@ class PrimaiteGame:
"""Advance timestep."""
self.step_counter += 1
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
self.update_agent_loggers()
self.simulation.apply_timestep(self.step_counter)
def update_agent_loggers(self) -> None:
"""Updates Agent Loggers with new timestep."""
for agent in self.agents.values():
agent.logger.update_timestep(self.step_counter)
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
@@ -195,6 +206,23 @@ class PrimaiteGame:
return True
return False
def action_mask(self, agent_name: str) -> np.ndarray:
"""
Return the action mask for the agent.
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
performed during this step.
:return: Action mask
:rtype: List[bool]
"""
agent = self.agents[agent_name]
mask = [True] * len(agent.action_manager.action_map)
for i, action in agent.action_manager.action_map.items():
request = agent.action_manager.form_request(action_identifier=action[0], action_options=action[1])
mask[i] = self.simulation._request_manager.check_valid(request, {})
return np.asarray(mask, dtype=np.int8)
def close(self) -> None:
"""Close the game, this will close the simulation."""
return NotImplemented
@@ -230,6 +258,12 @@ class PrimaiteGame:
simulation_config = cfg.get("simulation", {})
network_config = simulation_config.get("network", {})
airspace_cfg = network_config.get("airspace", {})
frequency_max_capacity_mbps_cfg = airspace_cfg.get("frequency_max_capacity_mbps", {})
frequency_max_capacity_mbps_cfg = {AirSpaceFrequency[k]: v for k, v in frequency_max_capacity_mbps_cfg.items()}
net.airspace.frequency_max_capacity_mbps_ = frequency_max_capacity_mbps_cfg
nodes_cfg = network_config.get("nodes", [])
links_cfg = network_config.get("links", [])

View File

@@ -0,0 +1,218 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Action Masking\n",
"\n",
"PrimAITE environments support action masking. The action mask shows which of the agent's actions are applicable with the current environment state. For example, a node can only be turned on if it is currently turned off."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from prettytable import PrettyTable\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env = PrimaiteGymEnv(data_manipulation_config_path())\n",
"env.action_masking = True"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The action mask is a list of booleans that specifies whether each action in the agent's action map is currently possible. Demonstrated here:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"act_table = PrettyTable((\"number\", \"action\", \"parameters\", \"mask\"))\n",
"mask = env.action_masks()\n",
"actions = env.agent.action_manager.action_map\n",
"max_str_len = 70\n",
"for act,mask in zip(actions.items(), mask):\n",
" act_num, act_data = act\n",
" act_type, act_params = act_data\n",
" act_params = s if len(s:=str(act_params))<max_str_len else f\"{s[:max_str_len-3]}...\"\n",
" act_table.add_row((act_num, act_type, act_params, mask))\n",
"print(act_table)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking for Stable Baselines3 agents\n",
"SB3 agents automatically use the action_masks method during the training loop"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sb3_contrib import MaskablePPO\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = MaskablePPO(\"MlpPolicy\", env, gamma=0.4, seed=32)\n",
"model.learn(1024)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking for Ray RLLib agents\n",
"Ray uses a different API to obtain action masks, but this is handled by the PrimaiteRayEnv and PrimaiteRayMarlEnv classes"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"import yaml\n",
"from ray import air, tune\n",
"from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule\n",
"from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"for agent in cfg['agents']:\n",
" if agent[\"ref\"] == \"defender\":\n",
" agent['agent_settings']['flatten_obs'] = True\n",
"env_config = cfg\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class = ActionMaskingTorchRLModule))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Action masking with MARL in Ray RLLib\n",
"Each agent has their own action mask, this is useful if the agents have different action spaces."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"from primaite.config.load import data_manipulation_marl_config_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_marl_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"env_config = cfg\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .multi_agent(\n",
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,\n",
" )\n",
" .api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key=\"action_mask\")\n",
" .rl_module(rl_module_spec=MultiAgentRLModuleSpec(module_specs={\n",
" \"defender_1\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" \"defender_2\":SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),\n",
" }))\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")\n",
"algo = config.build()\n",
"for i in range(2):\n",
" results = algo.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -4,6 +4,7 @@ from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
import gymnasium
import numpy as np
from gymnasium.core import ActType, ObsType
from primaite import getLogger
@@ -41,6 +42,21 @@ class PrimaiteGymEnv(gymnasium.Env):
self.total_reward_per_episode: Dict[int, float] = {}
"""Average rewards of agents per episode."""
def action_masks(self) -> np.ndarray:
"""
Return the action mask for the agent.
This is a boolean list corresponding to the agent's action space. A False entry means this action cannot be
performed during this step.
:return: Action mask
:rtype: List[bool]
"""
if not self.agent.action_masking:
return np.asarray([True] * len(self.agent.action_manager.action_map))
else:
return self.game.action_mask(self._agent_name)
@property
def agent(self) -> ProxyAgent:
"""Grab a fresh reference to the agent object because it will be reinstantiated each episode."""

View File

@@ -35,10 +35,16 @@ class PrimaiteIO:
"""Whether to save PCAP logs."""
save_sys_logs: bool = True
"""Whether to save system logs."""
save_agent_logs: bool = True
"""Whether to save agent logs."""
write_sys_log_to_terminal: bool = False
"""Whether to write the sys log to the terminal."""
write_agent_log_to_terminal: bool = False
"""Whether to write the agent log to the terminal."""
sys_log_level: LogLevel = LogLevel.INFO
"""The level of log that should be included in the logfiles/logged into terminal."""
"""The level of sys logs that should be included in the logfiles/logged into terminal."""
agent_log_level: LogLevel = LogLevel.INFO
"""The level of agent logs that should be included in the logfiles/logged into terminal."""
def __init__(self, settings: Optional[Settings] = None) -> None:
"""
@@ -51,27 +57,31 @@ class PrimaiteIO:
self.session_path: Path = self.generate_session_path()
# set global SIM_OUTPUT path
SIM_OUTPUT.path = self.session_path / "simulation_output"
SIM_OUTPUT.agent_behaviour_path = self.session_path / "agent_behaviour"
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
SIM_OUTPUT.save_agent_logs = self.settings.save_agent_logs
SIM_OUTPUT.write_agent_log_to_terminal = self.settings.write_agent_log_to_terminal
SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal
SIM_OUTPUT.sys_log_level = self.settings.sys_log_level
SIM_OUTPUT.agent_log_level = self.settings.agent_log_level
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
if timestamp is None:
timestamp = datetime.now()
date_str = timestamp.strftime("%Y-%m-%d")
time_str = timestamp.strftime("%H-%M-%S")
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
session_path = PRIMAITE_PATHS.user_sessions_path / SIM_OUTPUT.date_str / SIM_OUTPUT.time_str
# check if running in dev mode
if is_dev_mode():
session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str
session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / SIM_OUTPUT.date_str / SIM_OUTPUT.time_str
# check if there is an output directory set in config
if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
session_path = Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"]) / "sessions" / date_str / time_str
session_path = (
Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"])
/ "sessions"
/ SIM_OUTPUT.date_str
/ SIM_OUTPUT.time_str
)
session_path.mkdir(exist_ok=True, parents=True)
return session_path
@@ -115,6 +125,9 @@ class PrimaiteIO:
if config.get("sys_log_level"):
config["sys_log_level"] = LogLevel[config["sys_log_level"].upper()] # convert to enum
if config.get("agent_log_level"):
config["agent_log_level"] = LogLevel[config["agent_log_level"].upper()] # convert to enum
new = cls(settings=cls.Settings(**config))
return new

View File

@@ -3,6 +3,7 @@ import json
from typing import Dict, SupportsFloat, Tuple
import gymnasium
from gymnasium import spaces
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
@@ -38,15 +39,19 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
self.observation_space = spaces.Dict(
{name: spaces.flatten_space(agent.observation_manager.space) for name, agent in self.agents.items()}
)
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
if agent.action_masking:
self.observation_space[agent_name] = spaces.Dict(
{
"action_mask": spaces.MultiBinary(agent.action_manager.space.n),
"observations": self.observation_space[agent_name],
}
)
self.action_space = spaces.Dict({name: agent.action_manager.space for name, agent in self.agents.items()})
self._obs_space_in_preferred_format = True
self._action_space_in_preferred_format = True
super().__init__()
@@ -131,13 +136,17 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
all_obs = {}
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
obs = gymnasium.spaces.flatten(unflat_space, unflat_obs)
if agent.action_masking:
all_obs[agent_name] = {"action_mask": self.game.action_mask(agent_name), "observations": obs}
else:
all_obs[agent_name] = obs
return all_obs
def close(self):
"""Close the simulation."""
@@ -158,15 +167,30 @@ class PrimaiteRayEnv(gymnasium.Env):
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
if self.env.agent.action_masking:
self.observation_space = spaces.Dict(
{"action_mask": spaces.MultiBinary(self.env.action_space.n), "observations": self.env.observation_space}
)
else:
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
if self.env.agent.action_masking:
obs, *_ = self.env.reset(seed=seed)
new_obs = {"action_mask": self.env.action_masks(), "observations": obs}
return new_obs, *_
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
# if action masking is enabled, intercept the step method and add action mask to observation
if self.env.agent.action_masking:
obs, *_ = self.env.step(action)
new_obs = {"action_mask": self.game.action_mask(self.env._agent_name), "observations": obs}
return new_obs, *_
else:
return self.env.step(action)
def close(self):
"""Close the simulation."""

View File

@@ -3,6 +3,8 @@
developer_mode:
enabled: False # not enabled by default
sys_log_level: DEBUG # level of output for system logs, DEBUG by default
agent_log_level: DEBUG # level of output for agent logs, DEBUG by default
output_agent_logs: False # level of output for system logs, DEBUG by default
output_sys_logs: False # system logs not output by default
output_pcap_logs: False # pcap logs not output by default
output_to_terminal: False # do not output to terminal by default

View File

@@ -34,10 +34,14 @@ class _SimOutput:
path = PRIMAITE_PATHS.user_sessions_path / self.date_str / self.time_str
self._path = path
self._agent_behaviour_path = path
self._save_pcap_logs: bool = False
self._save_sys_logs: bool = False
self._save_agent_logs: bool = False
self._write_sys_log_to_terminal: bool = False
self._write_agent_log_to_terminal: bool = False
self._sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
self._agent_log_level: LogLevel = LogLevel.WARNING
@property
def path(self) -> Path:
@@ -61,6 +65,28 @@ class _SimOutput:
self._path = new_path
self._path.mkdir(exist_ok=True, parents=True)
@property
def agent_behaviour_path(self) -> Path:
if is_dev_mode():
# if dev mode is enabled, if output dir is not set, print to primaite repo root
path: Path = _PRIMAITE_ROOT.parent.parent / "sessions" / self.date_str / self.time_str / "agent_behaviour"
# otherwise print to output dir
if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
path: Path = (
Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"])
/ "sessions"
/ self.date_str
/ self.time_str
/ "agent_behaviour"
)
self._agent_behaviour_path = path
return self._agent_behaviour_path
@agent_behaviour_path.setter
def agent_behaviour_path(self, new_path: Path) -> None:
self._agent_behaviour_path = new_path
self._agent_behaviour_path.mkdir(exist_ok=True, parents=True)
@property
def save_pcap_logs(self) -> bool:
if is_dev_mode():
@@ -81,6 +107,16 @@ class _SimOutput:
def save_sys_logs(self, save_sys_logs: bool) -> None:
self._save_sys_logs = save_sys_logs
@property
def save_agent_logs(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_agent_logs")
return self._save_agent_logs
@save_agent_logs.setter
def save_agent_logs(self, save_agent_logs: bool) -> None:
self._save_agent_logs = save_agent_logs
@property
def write_sys_log_to_terminal(self) -> bool:
if is_dev_mode():
@@ -91,6 +127,17 @@ class _SimOutput:
def write_sys_log_to_terminal(self, write_sys_log_to_terminal: bool) -> None:
self._write_sys_log_to_terminal = write_sys_log_to_terminal
# Should this be separate from sys_log?
@property
def write_agent_log_to_terminal(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_to_terminal")
return self._write_agent_log_to_terminal
@write_agent_log_to_terminal.setter
def write_agent_log_to_terminal(self, write_agent_log_to_terminal: bool) -> None:
self._write_agent_log_to_terminal = write_agent_log_to_terminal
@property
def sys_log_level(self) -> LogLevel:
if is_dev_mode():
@@ -101,5 +148,15 @@ class _SimOutput:
def sys_log_level(self, sys_log_level: LogLevel) -> None:
self._sys_log_level = sys_log_level
@property
def agent_log_level(self) -> LogLevel:
if is_dev_mode():
return LogLevel[PRIMAITE_CONFIG.get("developer_mode").get("agent_log_level")]
return self._agent_log_level
@agent_log_level.setter
def agent_log_level(self, agent_log_level: LogLevel) -> None:
self._agent_log_level = agent_log_level
SIM_OUTPUT = _SimOutput()

View File

@@ -3,9 +3,10 @@
"""Core of the PrimAITE Simulator."""
import warnings
from abc import abstractmethod
from typing import Callable, Dict, List, Literal, Optional, Union
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
from uuid import uuid4
from prettytable import PrettyTable
from pydantic import BaseModel, ConfigDict, Field, validate_call
from primaite import getLogger
@@ -34,6 +35,20 @@ class RequestPermissionValidator(BaseModel):
"""Message that is reported when a request is rejected by this validator."""
return "request rejected"
def __add__(self, other: "RequestPermissionValidator") -> "_CombinedValidator":
return _CombinedValidator(validators=[self, other])
class _CombinedValidator(RequestPermissionValidator):
validators: List[RequestPermissionValidator] = []
def __call__(self, request, context) -> bool:
return all(x(request, context) for x in self.validators)
@property
def fail_message(self):
return f"One of the following conditions are not met: {[v.fail_message for v in self.validators]}"
class AllowAllValidator(RequestPermissionValidator):
"""Always allows the request."""
@@ -150,8 +165,17 @@ class RequestManager(BaseModel):
self.request_types.pop(name)
def get_request_types_recursively(self) -> List[List[str]]:
"""Recursively generate request tree for this component."""
def get_request_types_recursively(self) -> List[RequestFormat]:
"""
Recursively generate request tree for this component.
:param parent_valid: Whether this sub-request's parent request was valid. This value should not be specified by
users, it is used by the recursive call.
:type parent_valid: bool
:returns: A list of tuples where the first tuple element is the request string and the second is whether that
request is currently possible to execute.
:rtype: List[Tuple[RequestFormat, bool]]
"""
requests = []
for req_name, req in self.request_types.items():
if isinstance(req.func, RequestManager):
@@ -162,6 +186,30 @@ class RequestManager(BaseModel):
requests.append([req_name])
return requests
def show(self) -> None:
"""Display all currently available requests."""
table = PrettyTable(["requests"])
table.align = "l"
table.add_rows([[x] for x in self.get_request_types_recursively()])
print(table)
def check_valid(self, request: RequestFormat, context: Dict) -> bool:
"""Check if this request would be valid in the current state of the simulation without invoking it."""
request_key = request[0]
request_options = request[1:]
if request_key not in self.request_types:
return False
request_type = self.request_types[request_key]
# recurse if we are not at a leaf node
if isinstance(request_type.func, RequestManager):
return request_type.func.check_valid(request_options, context)
return request_type.validator(request_options, context)
class SimComponent(BaseModel):
"""Extension of pydantic BaseModel with additional methods that must be defined by all classes in the simulator."""

View File

@@ -52,6 +52,8 @@ class GroupMembershipValidator(RequestPermissionValidator):
def __call__(self, request: List[str], context: Dict) -> bool:
"""Permit the action if the request comes from an account which belongs to the right group."""
# if context request source is part of any groups mentioned in self.allow_groups, return true, otherwise false
if not context:
return False
requestor_groups: List[str] = context["request_source"]["groups"]
for allowed_group in self.allowed_groups:
if allowed_group.name in requestor_groups:

View File

@@ -6,8 +6,8 @@ from typing import Any, Dict, List, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType, SimComponent
from primaite.simulator.file_system.file import File
from primaite.simulator.file_system.file_type import FileType
from primaite.simulator.file_system.folder import Folder
@@ -42,6 +42,10 @@ class FileSystem(SimComponent):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
self._folder_exists = FileSystem._FolderExistsValidator(file_system=self)
self._folder_not_deleted = FileSystem._FolderNotDeletedValidator(file_system=self)
self._file_exists = FileSystem._FileExistsValidator(file_system=self)
rm = super()._init_request_manager()
self._delete_manager = RequestManager()
@@ -50,13 +54,15 @@ class FileSystem(SimComponent):
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(
self.delete_file(folder_name=request[0], file_name=request[1])
)
),
validator=self._file_exists,
),
)
self._delete_manager.add_request(
name="folder",
request_type=RequestType(
func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0]))
func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0])),
validator=self._folder_exists,
),
)
rm.add_request(
@@ -144,10 +150,13 @@ class FileSystem(SimComponent):
)
self._folder_request_manager = RequestManager()
rm.add_request("folder", RequestType(func=self._folder_request_manager))
rm.add_request(
"folder",
RequestType(func=self._folder_request_manager, validator=self._folder_exists + self._folder_not_deleted),
)
self._file_request_manager = RequestManager()
rm.add_request("file", RequestType(func=self._file_request_manager))
rm.add_request("file", RequestType(func=self._file_request_manager, validator=self._file_exists))
return rm
@@ -626,3 +635,62 @@ class FileSystem(SimComponent):
self.sys_log.error(f"Unable to access file that does not exist. (file name: {file_name})")
return False
class _FolderExistsValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the Folder exists.
Actions cannot be performed on a non-existent folder.
"""
file_system: FileSystem
"""Save a reference to the FileSystem instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Returns True if folder exists."""
return self.file_system.get_folder(folder_name=request[0]) is not None
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return "Cannot perform request on folder because it does not exist."
class _FolderNotDeletedValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the Folder has not been deleted.
Actions cannot be performed on a deleted folder.
"""
file_system: FileSystem
"""Save a reference to the FileSystem instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Returns True if folder exists and is not deleted."""
# get folder
folder = self.file_system.get_folder(folder_name=request[0], include_deleted=True)
return folder is not None and not folder.deleted
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return "Cannot perform request on folder because it is deleted."
class _FileExistsValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the File exists.
Actions cannot be performed on a non-existent file.
"""
file_system: FileSystem
"""Save a reference to the FileSystem instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Returns True if file exists."""
return self.file_system.get_file(folder_name=request[0], file_name=request[1]) is not None
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return "Cannot perform request on a file that does not exist."

View File

@@ -185,5 +185,5 @@ file_type_sizes_bytes = {
FileType.ZIP: 1024000,
FileType.TAR: 1024000,
FileType.GZ: 819200,
FileType.DB: 15360000,
FileType.DB: 5_000_000,
}

View File

@@ -6,8 +6,8 @@ from typing import Dict, Optional
from prettytable import MARKDOWN, PrettyTable
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.file_system.file import File
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
@@ -55,6 +55,9 @@ class Folder(FileSystemItemABC):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
self._file_exists = Folder._FileExistsValidator(folder=self)
self._file_not_deleted = Folder._FileNotDeletedValidator(folder=self)
rm = super()._init_request_manager()
rm.add_request(
name="delete",
@@ -65,7 +68,9 @@ class Folder(FileSystemItemABC):
self._file_request_manager = RequestManager()
rm.add_request(
name="file",
request_type=RequestType(func=self._file_request_manager),
request_type=RequestType(
func=self._file_request_manager, validator=self._file_exists + self._file_not_deleted
),
)
return rm
@@ -469,3 +474,42 @@ class Folder(FileSystemItemABC):
self.deleted = True
return True
class _FileExistsValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the File exists.
Actions cannot be performed on a non-existent file.
"""
folder: Folder
"""Save a reference to the Folder instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Returns True if file exists."""
return self.folder.get_file(file_name=request[0]) is not None
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return "Cannot perform request on a file that does not exist."
class _FileNotDeletedValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the File is not deleted.
Actions cannot be performed on a deleted file.
"""
folder: Folder
"""Save a reference to the Folder instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Returns True if file exists and is not deleted."""
file = self.folder.get_file(file_name=request[0])
return file is not None and not file.deleted
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return "Cannot perform request on a file that is deleted."

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List
from prettytable import PrettyTable
from prettytable import MARKDOWN, PrettyTable
from pydantic import BaseModel, Field
from primaite import getLogger
from primaite.simulator.network.hardware.base import Layer3Interface, NetworkInterface, WiredNetworkInterface
@@ -15,90 +16,29 @@ from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
__all__ = ["AirSpaceFrequency", "WirelessNetworkInterface", "IPWirelessNetworkInterface"]
def format_hertz(hertz: float, format_terahertz: bool = False, decimals: int = 3) -> str:
"""
Convert a frequency in Hertz to a formatted string using the most appropriate unit.
class AirSpace:
"""Represents a wireless airspace, managing wireless network interfaces and handling wireless transmission."""
Optionally includes formatting for Terahertz.
def __init__(self):
self._wireless_interfaces: Dict[str, WirelessNetworkInterface] = {}
self._wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = {}
def show(self, frequency: Optional[AirSpaceFrequency] = None):
"""
Displays a summary of wireless interfaces in the airspace, optionally filtered by a specific frequency.
:param frequency: The frequency band to filter devices by. If None, devices for all frequencies are shown.
"""
table = PrettyTable()
table.field_names = ["Connected Node", "MAC Address", "IP Address", "Subnet Mask", "Frequency", "Status"]
# If a specific frequency is provided, filter by it; otherwise, use all frequencies.
frequencies_to_show = [frequency] if frequency else self._wireless_interfaces_by_frequency.keys()
for freq in frequencies_to_show:
interfaces = self._wireless_interfaces_by_frequency.get(freq, [])
for interface in interfaces:
status = "Enabled" if interface.enabled else "Disabled"
table.add_row(
[
interface._connected_node.hostname, # noqa
interface.mac_address,
interface.ip_address if hasattr(interface, "ip_address") else None,
interface.subnet_mask if hasattr(interface, "subnet_mask") else None,
str(freq),
status,
]
)
print(table)
def add_wireless_interface(self, wireless_interface: WirelessNetworkInterface):
"""
Adds a wireless network interface to the airspace if it's not already present.
:param wireless_interface: The wireless network interface to be added.
"""
if wireless_interface.mac_address not in self._wireless_interfaces:
self._wireless_interfaces[wireless_interface.mac_address] = wireless_interface
if wireless_interface.frequency not in self._wireless_interfaces_by_frequency:
self._wireless_interfaces_by_frequency[wireless_interface.frequency] = []
self._wireless_interfaces_by_frequency[wireless_interface.frequency].append(wireless_interface)
def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface):
"""
Removes a wireless network interface from the airspace if it's present.
:param wireless_interface: The wireless network interface to be removed.
"""
if wireless_interface.mac_address in self._wireless_interfaces:
self._wireless_interfaces.pop(wireless_interface.mac_address)
self._wireless_interfaces_by_frequency[wireless_interface.frequency].remove(wireless_interface)
def clear(self):
"""
Clears all wireless network interfaces and their frequency associations from the airspace.
After calling this method, the airspace will contain no wireless network interfaces, and transmissions cannot
occur until new interfaces are added again.
"""
self._wireless_interfaces.clear()
self._wireless_interfaces_by_frequency.clear()
def transmit(self, frame: Frame, sender_network_interface: WirelessNetworkInterface):
"""
Transmits a frame to all enabled wireless network interfaces on a specific frequency within the airspace.
This ensures that a wireless interface does not receive its own transmission.
:param frame: The frame to be transmitted.
:param sender_network_interface: The wireless network interface sending the frame. This interface will be
excluded from the list of receivers to prevent it from receiving its own transmission.
"""
for wireless_interface in self._wireless_interfaces_by_frequency.get(sender_network_interface.frequency, []):
if wireless_interface != sender_network_interface and wireless_interface.enabled:
wireless_interface.receive_frame(frame)
:param hertz: Frequency in Hertz.
:param format_terahertz: Whether to format frequency in Terahertz, default is False.
:param decimals: Number of decimal places to round to, default is 3.
:returns: Formatted string with the frequency in the most suitable unit.
"""
format_str = f"{{:.{decimals}f}}"
if format_terahertz and hertz >= 1e12: # Terahertz
return format_str.format(hertz / 1e12) + " THz"
elif hertz >= 1e9: # Gigahertz
return format_str.format(hertz / 1e9) + " GHz"
elif hertz >= 1e6: # Megahertz
return format_str.format(hertz / 1e6) + " MHz"
elif hertz >= 1e3: # Kilohertz
return format_str.format(hertz / 1e3) + " kHz"
else: # Hertz
return format_str.format(hertz) + " Hz"
class AirSpaceFrequency(Enum):
@@ -110,12 +50,231 @@ class AirSpaceFrequency(Enum):
"""WiFi 5 GHz. Known for its higher data transmission speeds and reduced interference from other devices."""
def __str__(self) -> str:
hertz_str = format_hertz(hertz=self.value)
if self == AirSpaceFrequency.WIFI_2_4:
return "WiFi 2.4 GHz"
elif self == AirSpaceFrequency.WIFI_5:
return "WiFi 5 GHz"
else:
return "Unknown Frequency"
return f"WiFi {hertz_str}"
if self == AirSpaceFrequency.WIFI_5:
return f"WiFi {hertz_str}"
return "Unknown Frequency"
@property
def maximum_data_rate_bps(self) -> float:
"""
Retrieves the maximum data transmission rate in bits per second (bps) for the frequency.
The maximum rates are predefined for known frequencies:
- For WIFI_2_4, it returns 100,000,000 bps (100 Mbps).
- For WIFI_5, it returns 500,000,000 bps (500 Mbps).
:return: The maximum data rate in bits per second. If the frequency is not recognized, returns 0.0.
"""
if self == AirSpaceFrequency.WIFI_2_4:
return 100_000_000.0 # 100 Megabits per second
if self == AirSpaceFrequency.WIFI_5:
return 500_000_000.0 # 500 Megabits per second
return 0.0
@property
def maximum_data_rate_mbps(self) -> float:
"""
Retrieves the maximum data transmission rate in megabits per second (Mbps).
This is derived by converting the maximum data rate from bits per second, as defined
in `maximum_data_rate_bps`, to megabits per second.
:return: The maximum data rate in megabits per second.
"""
return self.maximum_data_rate_bps / 1_000_000.0
class AirSpace(BaseModel):
"""
Represents a wireless airspace, managing wireless network interfaces and handling wireless transmission.
This class provides functionalities to manage a collection of wireless network interfaces, each associated with
specific frequencies. It includes methods to add and remove wireless interfaces, and handle data transmission
across these interfaces.
"""
wireless_interfaces: Dict[str, WirelessNetworkInterface] = Field(default_factory=lambda: {})
wireless_interfaces_by_frequency: Dict[AirSpaceFrequency, List[WirelessNetworkInterface]] = Field(
default_factory=lambda: {}
)
bandwidth_load: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {})
frequency_max_capacity_mbps_: Dict[AirSpaceFrequency, float] = Field(default_factory=lambda: {})
def get_frequency_max_capacity_mbps(self, frequency: AirSpaceFrequency) -> float:
"""
Retrieves the maximum data transmission capacity for a specified frequency.
This method checks a dictionary holding custom maximum capacities. If the frequency is found, it returns the
custom set maximum capacity. If the frequency is not found in the dictionary, it defaults to the standard
maximum data rate associated with that frequency.
:param frequency: The frequency for which the maximum capacity is queried.
:return: The maximum capacity in Mbps for the specified frequency.
"""
if frequency in self.frequency_max_capacity_mbps_:
return self.frequency_max_capacity_mbps_[frequency]
return frequency.maximum_data_rate_mbps
def set_frequency_max_capacity_mbps(self, cfg: Dict[AirSpaceFrequency, float]):
"""
Sets custom maximum data transmission capacities for multiple frequencies.
:param cfg: A dictionary mapping frequencies to their new maximum capacities in Mbps.
"""
self.frequency_max_capacity_mbps_ = cfg
for freq, mbps in cfg.items():
print(f"Overriding {freq} max capacity as {mbps:.3f} mbps")
def show_bandwidth_load(self, markdown: bool = False):
"""
Prints a table of the current bandwidth load for each frequency on the airspace.
This method prints a tabulated view showing the utilisation of available bandwidth capacities for all
frequencies. The table includes the current capacity usage as a percentage of the maximum capacity, alongside
the absolute maximum capacity values in Mbps.
:param markdown: Flag indicating if output should be in markdown format.
"""
headers = ["Frequency", "Current Capacity (%)", "Maximum Capacity (Mbit)"]
table = PrettyTable(headers)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = "Airspace Frequency Channel Loads"
for frequency, load in self.bandwidth_load.items():
maximum_capacity = self.get_frequency_max_capacity_mbps(frequency)
load_percent = load / maximum_capacity if maximum_capacity > 0 else 0.0
if load_percent > 1.0:
load_percent = 1.0
table.add_row([format_hertz(frequency.value), f"{load_percent:.0%}", f"{maximum_capacity:.3f}"])
print(table)
def show_wireless_interfaces(self, markdown: bool = False):
"""
Prints a table of wireless interfaces in the airspace.
:param markdown: Flag indicating if output should be in markdown format.
"""
headers = [
"Connected Node",
"MAC Address",
"IP Address",
"Subnet Mask",
"Frequency",
"Speed (Mbps)",
"Status",
]
table = PrettyTable(headers)
if markdown:
table.set_style(MARKDOWN)
table.align = "l"
table.title = "Devices on Air Space"
for interface in self.wireless_interfaces.values():
status = "Enabled" if interface.enabled else "Disabled"
table.add_row(
[
interface._connected_node.hostname, # noqa
interface.mac_address,
interface.ip_address if hasattr(interface, "ip_address") else None,
interface.subnet_mask if hasattr(interface, "subnet_mask") else None,
format_hertz(interface.frequency.value),
f"{interface.speed:.3f}",
status,
]
)
print(table.get_string(sortby="Frequency"))
def show(self, markdown: bool = False):
"""
Prints a summary of the current state of the airspace, including both wireless interfaces and bandwidth loads.
This method is a convenient wrapper that calls two separate methods to display detailed tables: one for
wireless interfaces and another for bandwidth load across all frequencies managed within the airspace. It
provides a holistic view of the operational status and performance metrics of the airspace.
:param markdown: Flag indicating if output should be in markdown format.
"""
self.show_wireless_interfaces(markdown)
self.show_bandwidth_load(markdown)
def add_wireless_interface(self, wireless_interface: WirelessNetworkInterface):
"""
Adds a wireless network interface to the airspace if it's not already present.
:param wireless_interface: The wireless network interface to be added.
"""
if wireless_interface.mac_address not in self.wireless_interfaces:
self.wireless_interfaces[wireless_interface.mac_address] = wireless_interface
if wireless_interface.frequency not in self.wireless_interfaces_by_frequency:
self.wireless_interfaces_by_frequency[wireless_interface.frequency] = []
self.wireless_interfaces_by_frequency[wireless_interface.frequency].append(wireless_interface)
def remove_wireless_interface(self, wireless_interface: WirelessNetworkInterface):
"""
Removes a wireless network interface from the airspace if it's present.
:param wireless_interface: The wireless network interface to be removed.
"""
if wireless_interface.mac_address in self.wireless_interfaces:
self.wireless_interfaces.pop(wireless_interface.mac_address)
self.wireless_interfaces_by_frequency[wireless_interface.frequency].remove(wireless_interface)
def clear(self):
"""
Clears all wireless network interfaces and their frequency associations from the airspace.
After calling this method, the airspace will contain no wireless network interfaces, and transmissions cannot
occur until new interfaces are added again.
"""
self.wireless_interfaces.clear()
self.wireless_interfaces_by_frequency.clear()
def reset_bandwidth_load(self):
"""
Resets the bandwidth load tracking for all frequencies in the airspace.
This method clears the current load metrics for all operating frequencies, effectively setting the load to zero.
"""
self.bandwidth_load = {}
def can_transmit_frame(self, frame: Frame, sender_network_interface: WirelessNetworkInterface) -> bool:
"""
Determines if a frame can be transmitted by the sender network interface based on the current bandwidth load.
This method checks if adding the size of the frame to the current bandwidth load of the frequency used by the
sender network interface would exceed the maximum allowed bandwidth for that frequency. It returns True if the
frame can be transmitted without exceeding the limit, and False otherwise.
:param frame: The frame to be transmitted, used to check its size against the frequency's bandwidth limit.
:param sender_network_interface: The network interface attempting to transmit the frame, used to determine the
relevant frequency and its current bandwidth load.
:return: True if the frame can be transmitted within the bandwidth limit, False if it would exceed the limit.
"""
if sender_network_interface.frequency not in self.bandwidth_load:
self.bandwidth_load[sender_network_interface.frequency] = 0.0
return self.bandwidth_load[
sender_network_interface.frequency
] + frame.size_Mbits <= self.get_frequency_max_capacity_mbps(sender_network_interface.frequency)
def transmit(self, frame: Frame, sender_network_interface: WirelessNetworkInterface):
"""
Transmits a frame to all enabled wireless network interfaces on a specific frequency within the airspace.
This ensures that a wireless interface does not receive its own transmission.
:param frame: The frame to be transmitted.
:param sender_network_interface: The wireless network interface sending the frame. This interface will be
excluded from the list of receivers to prevent it from receiving its own transmission.
"""
self.bandwidth_load[sender_network_interface.frequency] += frame.size_Mbits
for wireless_interface in self.wireless_interfaces_by_frequency.get(sender_network_interface.frequency, []):
if wireless_interface != sender_network_interface and wireless_interface.enabled:
wireless_interface.receive_frame(frame)
class WirelessNetworkInterface(NetworkInterface, ABC):
@@ -185,13 +344,18 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
:param frame: The network frame to be sent.
:return: True if the frame is sent successfully, False if the network interface is disabled.
"""
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture_outbound(frame)
self.airspace.transmit(frame, self)
return True
# Cannot send Frame as the network interface is not enabled
return False
if not self.enabled:
return False
if not self.airspace.can_transmit_frame(frame, self):
# Drop frame for now. Queuing will happen here (probably) if it's done in the future.
self._connected_node.sys_log.info(f"{self}: Frame dropped as Link is at capacity")
return False
super().send_frame(frame)
frame.set_sent_timestamp()
self.pcap.capture_outbound(frame)
self.airspace.transmit(frame, self)
return True
def receive_frame(self, frame: Frame) -> bool:
"""

View File

@@ -96,6 +96,8 @@ class Network(SimComponent):
"""Apply pre-timestep logic."""
super().pre_timestep(timestep)
self.airspace.reset_bandwidth_load()
for node in self.nodes.values():
node.pre_timestep(timestep)

View File

@@ -78,7 +78,7 @@ class NetworkInterface(SimComponent, ABC):
mac_address: str = Field(default_factory=generate_mac_address)
"The MAC address of the interface."
speed: int = 100
speed: float = 100.0
"The speed of the interface in Mbps. Default is 100 Mbps."
mtu: int = 1500
@@ -124,10 +124,25 @@ class NetworkInterface(SimComponent, ABC):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
_is_network_interface_enabled = NetworkInterface._EnabledValidator(network_interface=self)
_is_network_interface_disabled = NetworkInterface._DisabledValidator(network_interface=self)
rm = super()._init_request_manager()
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
rm.add_request(
"enable",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.enable()),
validator=_is_network_interface_disabled,
),
)
rm.add_request(
"disable",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.disable()),
validator=_is_network_interface_enabled,
),
)
return rm
@@ -326,6 +341,50 @@ class NetworkInterface(SimComponent, ABC):
super().pre_timestep(timestep)
self.traffic = {}
class _EnabledValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the NetworkInterface is enabled.
This is useful because most actions should be being resolved if the NetworkInterface is disabled.
"""
network_interface: NetworkInterface
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the NetworkInterface is enabled or not."""
return self.network_interface.enabled
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on NetworkInterface "
f"'{self.network_interface.mac_address}' because it is not enabled."
)
class _DisabledValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the NetworkInterface is disabled.
This is useful because some actions should be being resolved if the NetworkInterface is disabled.
"""
network_interface: NetworkInterface
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the NetworkInterface is disabled or not."""
return not self.network_interface.enabled
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on NetworkInterface "
f"'{self.network_interface.mac_address}' because it is not disabled."
)
class WiredNetworkInterface(NetworkInterface, ABC):
"""
@@ -434,14 +493,17 @@ class WiredNetworkInterface(NetworkInterface, ABC):
:param frame: The network frame to be sent.
:return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link.
"""
if not self.enabled:
return False
if not self._connected_link.can_transmit_frame(frame):
# Drop frame for now. Queuing will happen here (probably) if it's done in the future.
self._connected_node.sys_log.info(f"{self}: Frame dropped as Link is at capacity")
return False
super().send_frame(frame)
if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture_outbound(frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the NIC is not enabled
return False
frame.set_sent_timestamp()
self.pcap.capture_outbound(frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
@abstractmethod
def receive_frame(self, frame: Frame) -> bool:
@@ -672,12 +734,21 @@ class Link(SimComponent):
"""
return self.endpoint_a.enabled and self.endpoint_b.enabled
def _can_transmit(self, frame: Frame) -> bool:
def can_transmit_frame(self, frame: Frame) -> bool:
"""
Determines whether a frame can be transmitted considering the current Link load and the Link's bandwidth.
This method assesses if the transmission of a given frame is possible without exceeding the Link's total
bandwidth capacity. It checks if the current load of the Link plus the size of the frame (expressed in Mbps)
would remain within the defined bandwidth limits. The transmission is only feasible if the Link is active
('up') and the total load including the new frame does not surpass the bandwidth limit.
:param frame: The frame intended for transmission, which contains its size in Mbps.
:return: True if the frame can be transmitted without exceeding the bandwidth limit, False otherwise.
"""
if self.is_up:
frame_size_Mbits = frame.size_Mbits # noqa - Leaving it as Mbits as this is how they're expressed
# return self.current_load + frame_size_Mbits <= self.bandwidth
# TODO: re add this check once packet size limiting and MTU checks are implemented
return True
return self.current_load + frame.size_Mbits <= self.bandwidth
return False
def transmit_frame(self, sender_nic: WiredNetworkInterface, frame: Frame) -> bool:
@@ -688,11 +759,6 @@ class Link(SimComponent):
:param frame: The network frame to be sent.
:return: True if the Frame can be sent, otherwise False.
"""
can_transmit = self._can_transmit(frame)
if not can_transmit:
_LOGGER.debug(f"Cannot transmit frame as {self} is at capacity")
return False
receiver = self.endpoint_a
if receiver == sender_nic:
receiver = self.endpoint_b
@@ -872,6 +938,25 @@ class Node(SimComponent):
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned on."
class _NodeIsOffValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the node is off.
This is useful because some actions require the node to be in an off state.
"""
node: Node
"""Save a reference to the node instance."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the node is on or off."""
return self.node.operating_state == NodeOperatingState.OFF
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return f"Cannot perform request on node '{self.node.hostname}' because it is not turned off."
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager.
@@ -934,6 +1019,7 @@ class Node(SimComponent):
return RequestResponse.from_bool(False)
_node_is_on = Node._NodeIsOnValidator(node=self)
_node_is_off = Node._NodeIsOffValidator(node=self)
rm = super()._init_request_manager()
# since there are potentially many services, create an request manager that can map service name
@@ -963,7 +1049,12 @@ class Node(SimComponent):
func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on
),
)
rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on())))
rm.add_request(
"startup",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.power_on()), validator=_node_is_off
),
)
rm.add_request(
"reset",
RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on),

View File

@@ -58,12 +58,16 @@ class SwitchPort(WiredNetworkInterface):
:param frame: The network frame to be sent.
:return: A boolean indicating whether the frame was successfully sent.
"""
if self.enabled:
self.pcap.capture_outbound(frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
# Cannot send Frame as the SwitchPort is not enabled
return False
if not self.enabled:
return False
if not self._connected_link.can_transmit_frame(frame):
# Drop frame for now. Queuing will happen here (probably) if it's done in the future.
self._connected_node.sys_log.info(f"{self}: Frame dropped as Link is at capacity")
return False
self.pcap.capture_outbound(frame)
self._connected_link.transmit_frame(sender_nic=self, frame=frame)
return True
def receive_frame(self, frame: Frame) -> bool:
"""

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from ipaddress import IPv4Address
from typing import Any, Dict, Union
from typing import Any, Dict, Optional, Union
from pydantic import validate_call
@@ -153,7 +153,7 @@ class WirelessRouter(Router):
self,
ip_address: IPV4Address,
subnet_mask: IPV4Address,
frequency: AirSpaceFrequency = AirSpaceFrequency.WIFI_2_4,
frequency: Optional[AirSpaceFrequency] = AirSpaceFrequency.WIFI_2_4,
):
"""
Configures a wireless access point (WAP).
@@ -170,13 +170,20 @@ class WirelessRouter(Router):
enum. This determines the frequency band (e.g., 2.4 GHz or 5 GHz) the access point will use for wireless
communication. Default is AirSpaceFrequency.WIFI_2_4.
"""
if not frequency:
frequency = AirSpaceFrequency.WIFI_2_4
self.sys_log.info("Configuring wireless access point")
self.wireless_access_point.disable() # Temporarily disable the WAP for reconfiguration
network_interface = self.network_interface[1]
network_interface.ip_address = ip_address
network_interface.subnet_mask = subnet_mask
self.sys_log.info(f"Configured WAP {network_interface}")
self.wireless_access_point.frequency = frequency # Set operating frequency
self.wireless_access_point.enable() # Re-enable the WAP with new settings
self.sys_log.info(f"Configured WAP {network_interface}")
@property
def router_interface(self) -> RouterInterface:

View File

@@ -133,10 +133,11 @@ class Frame(BaseModel):
def size(self) -> float: # noqa - Keep it as MBits as this is how they're expressed
"""The size of the Frame in Bytes."""
# get the payload size if it is a data packet
payload_size = 0.0
if isinstance(self.payload, DataPacket):
return self.payload.get_packet_size()
payload_size = self.payload.get_packet_size()
return float(len(self.model_dump_json().encode("utf-8")))
return float(len(self.model_dump_json().encode("utf-8"))) + payload_size
@property
def size_Mbits(self) -> float: # noqa - Keep it as MBits as this is how they're expressed

View File

@@ -1,10 +1,12 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from enum import Enum
from typing import Any, ClassVar, Dict, Optional, Set, Type
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
@@ -64,9 +66,27 @@ class Application(IOSoftware):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
rm = super()._init_request_manager()
_is_application_running = Application._StateValidator(application=self, state=ApplicationOperatingState.RUNNING)
rm.add_request("close", RequestType(func=lambda request, context: RequestResponse.from_bool(self.close())))
rm = super()._init_request_manager()
rm.add_request(
"scan",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_application_running
),
)
rm.add_request(
"close",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.close()), validator=_is_application_running
),
)
rm.add_request(
"fix",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_application_running
),
)
return rm
@abstractmethod
@@ -169,3 +189,28 @@ class Application(IOSoftware):
:return: True if successful, False otherwise.
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
class _StateValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the application is in the correct state.
This is useful because most actions require the application to be in a specific state.
"""
application: Application
"""Save a reference to the application instance."""
state: ApplicationOperatingState
"""The state of the application to validate."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the application is in the state we are validating for."""
return self.application.operating_state == self.state
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on application '{self.application.name}' because it is not in the "
f"{self.state.name} state."
)

View File

@@ -1,11 +1,13 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from __future__ import annotations
from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType
from primaite.interface.request import RequestFormat, RequestResponse
from primaite.simulator.core import RequestManager, RequestPermissionValidator, RequestType
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
_LOGGER = getLogger(__name__)
@@ -40,6 +42,7 @@ class Service(IOSoftware):
restart_duration: int = 5
"How many timesteps does it take to restart this service."
restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
@@ -86,15 +89,61 @@ class Service(IOSoftware):
More information in user guide and docstring for SimComponent._init_request_manager.
"""
_is_service_running = Service._StateValidator(service=self, state=ServiceOperatingState.RUNNING)
_is_service_stopped = Service._StateValidator(service=self, state=ServiceOperatingState.STOPPED)
_is_service_paused = Service._StateValidator(service=self, state=ServiceOperatingState.PAUSED)
_is_service_disabled = Service._StateValidator(service=self, state=ServiceOperatingState.DISABLED)
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())))
rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop())))
rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start())))
rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause())))
rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume())))
rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart())))
rm.add_request(
"scan",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_is_service_running
),
)
rm.add_request(
"stop",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.stop()), validator=_is_service_running
),
)
rm.add_request(
"start",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.start()), validator=_is_service_stopped
),
)
rm.add_request(
"pause",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.pause()), validator=_is_service_running
),
)
rm.add_request(
"resume",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.resume()), validator=_is_service_paused
),
)
rm.add_request(
"restart",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.restart()), validator=_is_service_running
),
)
rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable())))
rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable())))
rm.add_request(
"enable",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.enable()), validator=_is_service_disabled
),
)
rm.add_request(
"fix",
RequestType(
func=lambda request, context: RequestResponse.from_bool(self.fix()), validator=_is_service_running
),
)
return rm
@abstractmethod
@@ -191,3 +240,28 @@ class Service(IOSoftware):
self.sys_log.debug(f"Restarting finished for service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING
self.restart_countdown -= 1
class _StateValidator(RequestPermissionValidator):
"""
When requests come in, this validator will only let them through if the service is in the correct state.
This is useful because most actions require the service to be in a specific state.
"""
service: Service
"""Save a reference to the service instance."""
state: ServiceOperatingState
"""The state of the service to validate."""
def __call__(self, request: RequestFormat, context: Dict) -> bool:
"""Return whether the service is in the state we are validating for."""
return self.service.operating_state == self.state
@property
def fail_message(self) -> str:
"""Message that is reported when a request is rejected by this validator."""
return (
f"Cannot perform request on service '{self.service.name}' because it is not in the "
f"{self.state.name} state."
)

View File

@@ -82,12 +82,31 @@ def config_callback(
show_default=False,
),
] = None,
agent_log_level: Annotated[
LogLevel,
typer.Option(
"--agent-log-level",
"-level",
click_type=click.Choice(LogLevel._member_names_, case_sensitive=False),
help="The level of agent behaviour logs to output.",
show_default=False,
),
] = None,
output_sys_logs: Annotated[
bool,
typer.Option(
"--output-sys-logs/--no-sys-logs", "-sys/-nsys", help="Output system logs to file.", show_default=False
),
] = None,
output_agent_logs: Annotated[
bool,
typer.Option(
"--output-agent-logs/--no-agent-logs",
"-agent/-nagent",
help="Output agent logs to file.",
show_default=False,
),
] = None,
output_pcap_logs: Annotated[
bool,
typer.Option(
@@ -109,10 +128,18 @@ def config_callback(
PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] = ctx.params.get("sys_log_level")
print(f"PrimAITE dev-mode config updated sys_log_level={ctx.params.get('sys_log_level')}")
if ctx.params.get("agent_log_level") is not None:
PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] = ctx.params.get("agent_log_level")
print(f"PrimAITE dev-mode config updated agent_log_level={ctx.params.get('agent_log_level')}")
if output_sys_logs is not None:
PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] = output_sys_logs
print(f"PrimAITE dev-mode config updated {output_sys_logs=}")
if output_agent_logs is not None:
PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] = output_agent_logs
print(f"PrimAITE dev-mode config updated {output_agent_logs=}")
if output_pcap_logs is not None:
PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] = output_pcap_logs
print(f"PrimAITE dev-mode config updated {output_pcap_logs=}")

View File

@@ -9,6 +9,9 @@ io_settings:
save_pcap_logs: true
save_sys_logs: true
sys_log_level: WARNING
agent_log_level: INFO
save_agent_logs: true
write_agent_log_to_terminal: True
game:

File diff suppressed because it is too large Load Diff

View File

@@ -41,6 +41,12 @@ agents:
options:
source_node: client_1
target_ip_address: 192.168.10.0/24
target_port:
- 21
- 53
- 80
- 123
- 219
reward_function:
reward_components:

View File

@@ -177,6 +177,9 @@ simulation:
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: NMAP
options:
fix_duration: 1
- type: RansomwareScript
options:
fix_duration: 1

View File

@@ -243,25 +243,25 @@ agents:
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
file_id: 0
13:
action: "NODE_SERVICE_FIX"
@@ -272,22 +272,22 @@ agents:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
folder_id: 0
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
folder_id: 0
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
folder_id: 0
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
folder_id: 0
18:
action: "NODE_OS_SCAN"
options:
@@ -518,11 +518,22 @@ agents:
nodes:
- node_name: domain_controller
- node_name: web_server
applications:
- application_name: DatabaseClient
services:
- service_name: WebServer
- node_name: database_server
folders:
- folder_name: database
files:
- file_name: database.db
services:
- service_name: DatabaseService
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
@@ -557,6 +568,7 @@ agents:
agent_settings:
flatten_obs: true
action_masking: true
@@ -634,6 +646,8 @@ simulation:
dns_server: 192.168.1.10
services:
- type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- type: server
hostname: backup_server

View File

@@ -0,0 +1,81 @@
game:
max_episode_length: 256
ports:
- ARP
protocols:
- ICMP
- TCP
- UDP
simulation:
network:
airspace:
frequency_max_capacity_mbps:
WIFI_2_4: 123.45
WIFI_5: 0.0
nodes:
- type: computer
hostname: pc_a
ip_address: 192.168.0.2
subnet_mask: 255.255.255.0
default_gateway: 192.168.0.1
start_up_duration: 0
- type: computer
hostname: pc_b
ip_address: 192.168.2.2
subnet_mask: 255.255.255.0
default_gateway: 192.168.2.1
start_up_duration: 0
- type: wireless_router
hostname: router_1
start_up_duration: 0
router_interface:
ip_address: 192.168.0.1
subnet_mask: 255.255.255.0
wireless_access_point:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
frequency: WIFI_2_4
acl:
1:
action: PERMIT
routes:
- address: 192.168.2.0 # PC B subnet
subnet_mask: 255.255.255.0
next_hop_ip_address: 192.168.1.2
metric: 0
- type: wireless_router
hostname: router_2
start_up_duration: 0
router_interface:
ip_address: 192.168.2.1
subnet_mask: 255.255.255.0
wireless_access_point:
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
frequency: WIFI_2_4
acl:
1:
action: PERMIT
routes:
- address: 192.168.0.0 # PC A subnet
subnet_mask: 255.255.255.0
next_hop_ip_address: 192.168.1.1
metric: 0
links:
- endpoint_a_hostname: pc_a
endpoint_a_port: 1
endpoint_b_hostname: router_1
endpoint_b_port: 2
- endpoint_a_hostname: pc_b
endpoint_a_port: 1
endpoint_b_hostname: router_2
endpoint_b_port: 2

View File

@@ -0,0 +1,81 @@
game:
max_episode_length: 256
ports:
- ARP
protocols:
- ICMP
- TCP
- UDP
simulation:
network:
airspace:
frequency_max_capacity_mbps:
WIFI_2_4: 0.0
WIFI_5: 0.0
nodes:
- type: computer
hostname: pc_a
ip_address: 192.168.0.2
subnet_mask: 255.255.255.0
default_gateway: 192.168.0.1
start_up_duration: 0
- type: computer
hostname: pc_b
ip_address: 192.168.2.2
subnet_mask: 255.255.255.0
default_gateway: 192.168.2.1
start_up_duration: 0
- type: wireless_router
hostname: router_1
start_up_duration: 0
router_interface:
ip_address: 192.168.0.1
subnet_mask: 255.255.255.0
wireless_access_point:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
frequency: WIFI_2_4
acl:
1:
action: PERMIT
routes:
- address: 192.168.2.0 # PC B subnet
subnet_mask: 255.255.255.0
next_hop_ip_address: 192.168.1.2
metric: 0
- type: wireless_router
hostname: router_2
start_up_duration: 0
router_interface:
ip_address: 192.168.2.1
subnet_mask: 255.255.255.0
wireless_access_point:
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
frequency: WIFI_2_4
acl:
1:
action: PERMIT
routes:
- address: 192.168.0.0 # PC A subnet
subnet_mask: 255.255.255.0
next_hop_ip_address: 192.168.1.1
metric: 0
links:
- endpoint_a_hostname: pc_a
endpoint_a_port: 1
endpoint_b_hostname: router_1
endpoint_b_port: 2
- endpoint_a_hostname: pc_b
endpoint_a_port: 1
endpoint_b_hostname: router_2
endpoint_b_port: 2

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, Tuple
import pytest
import yaml
from ray import init as rayinit
from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager
@@ -29,6 +30,7 @@ from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
rayinit(local_mode=True)
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
@@ -87,7 +89,10 @@ def service_class():
@pytest.fixture(scope="function")
def application(file_system) -> DummyApplication:
return DummyApplication(
name="DummyApplication", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="dummy_application")
name="DummyApplication",
port=Port.ARP,
file_system=file_system,
sys_log=SysLog(hostname="dummy_application"),
)
@@ -252,8 +257,7 @@ def example_network() -> Network:
server_2.power_on()
network.connect(endpoint_b=server_2.network_interface[1], endpoint_a=switch_1.network_interface[2])
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
router_1.acl.add_rule(action=ACLAction.PERMIT, position=1)
assert all(link.is_up for link in network.links.values())

View File

@@ -0,0 +1 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK

View File

@@ -0,0 +1,156 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Dict
import yaml
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.core.rl_module.marl_module import MultiAgentRLModuleSpec
from ray.rllib.core.rl_module.rl_module import SingleAgentRLModuleSpec
from ray.rllib.examples.rl_modules.classes.action_masking_rlm import ActionMaskingTorchRLModule
from sb3_contrib import MaskablePPO
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
from tests import TEST_ASSETS_ROOT
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
MARL_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
def test_sb3_action_masking(monkeypatch):
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the action
# mask function here to save the output of the action mask method and pass through the result back to the agent.
old_action_mask_method = PrimaiteGame.action_mask
mask_history = []
def cache_action_mask(obj, agent_name):
mask = old_action_mask_method(obj, agent_name)
mask_history.append(mask)
return mask
# Even though it's easy to know which CAOS action the agent took by looking at agent history, we don't know which
# action map action integer that was, therefore we cache it by using monkeypatch
action_num_history = []
def cache_step(env, action: int):
action_num_history.append(action)
return PrimaiteGymEnv.step(env, action)
monkeypatch.setattr(PrimaiteGame, "action_mask", cache_action_mask)
env = PrimaiteGymEnv(CFG_PATH)
monkeypatch.setattr(env, "step", lambda action: cache_step(env, action))
model = MaskablePPO("MlpPolicy", env, gamma=0.4, seed=32, batch_size=32)
model.learn(256)
assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
assert any([not all(x) for x in mask_history])
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
# the N-th entry was True, meaning that it was a valid action at that step.
# This plucks out the mask history at step i, and at action entry a and checks that it's set to True, and this
# happens for all steps i in the episode
assert all(mask_history[i][a] for i, a in enumerate(action_num_history))
monkeypatch.undo()
def test_ray_single_agent_action_masking(monkeypatch):
"""Check that a Ray agent uses the action mask and never chooses invalid actions."""
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
for agent in cfg["agents"]:
if agent["ref"] == "defender":
agent["agent_settings"]["flatten_obs"] = True
# There's no simple way of capturing what the action mask was at every step, therefore we are mocking the step
# function to save the action mask and the agent's chosen action to a local variable.
old_step_method = PrimaiteRayEnv.step
action_num_history = []
mask_history = []
def cache_step(self, action: int):
action_num_history.append(action)
obs, *_ = old_step_method(self, action)
action_mask = obs["action_mask"]
mask_history.append(action_mask)
return obs, *_
monkeypatch.setattr(PrimaiteRayEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
# Configure Ray PPO to use action masking by using the ActionMaskingTorchRLModule
config = (
PPOConfig()
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
.environment(env=PrimaiteRayEnv, env_config=cfg, action_mask_key="action_mask")
.rl_module(rl_module_spec=SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule))
.env_runners(num_env_runners=0)
.training(train_batch_size=128)
)
algo = config.build()
algo.train()
assert len(action_num_history) == len(mask_history) > 0
# Make sure the masks had at least some False entries, if it was all True then the mask was disabled
assert any([not all(x) for x in mask_history])
# When the agent takes action N from its action map, we need to have a look at the action mask and make sure that
# the N-th action was valid.
# The first step uses the action mask provided by the reset method, so we are only checking from the second step
# onward, that's why we need to use mask_history[:-1] and action_num_history[1:]
assert all(mask_history[:-1][i][a] for i, a in enumerate(action_num_history[1:]))
monkeypatch.undo()
def test_ray_multi_agent_action_masking(monkeypatch):
"""Check that Ray agents never take invalid actions when using MARL."""
with open(MARL_PATH, "r") as f:
cfg = yaml.safe_load(f)
old_step_method = PrimaiteRayMARLEnv.step
action_num_history = {"defender_1": [], "defender_2": []}
mask_history = {"defender_1": [], "defender_2": []}
def cache_step(self, actions: Dict[str, int]):
for agent_name, action in actions.items():
action_num_history[agent_name].append(action)
obs, *_ = old_step_method(self, actions)
for (
agent_name,
o,
) in obs.items():
mask_history[agent_name].append(o["action_mask"])
return obs, *_
monkeypatch.setattr(PrimaiteRayMARLEnv, "step", lambda *args, **kwargs: cache_step(*args, **kwargs))
config = (
PPOConfig()
.multi_agent(
policies={
"defender_1",
"defender_2",
}, # These names are the same as the agents defined in the example config.
policy_mapping_fn=lambda agent_id, *args, **kwargs: agent_id,
)
.api_stack(enable_rl_module_and_learner=True, enable_env_runner_and_connector_v2=True)
.environment(env=PrimaiteRayMARLEnv, env_config=cfg, action_mask_key="action_mask")
.rl_module(
rl_module_spec=MultiAgentRLModuleSpec(
module_specs={
"defender_1": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
"defender_2": SingleAgentRLModuleSpec(module_class=ActionMaskingTorchRLModule),
}
)
)
.env_runners(num_env_runners=0)
.training(train_batch_size=128)
)
algo = config.build()
algo.train()
for agent_name in ["defender_1", "defender_2"]:
act_hist = action_num_history[agent_name]
mask_hist = mask_history[agent_name]
assert len(act_hist) == len(mask_hist) > 0
assert any([not all(x) for x in mask_hist])
assert all(mask_hist[:-1][i][a] for i, a in enumerate(act_hist[1:]))
monkeypatch.undo()

View File

@@ -1,7 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import ray
import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.session.ray_envs import PrimaiteRayMARLEnv
@@ -12,12 +10,9 @@ MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
def test_rllib_multi_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
with open(MULTI_AGENT_PATH, "r") as f:
cfg = yaml.safe_load(f)
ray.init()
config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config=cfg)
@@ -28,15 +23,5 @@ def test_rllib_multi_agent_compatibility():
)
.training(train_batch_size=128)
)
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": 128},
checkpoint_config=air.CheckpointConfig(
checkpoint_frequency=10,
),
),
param_space=config,
).fit()
ray.shutdown()
algo = config.build()
algo.train()

View File

@@ -3,7 +3,6 @@ import tempfile
from pathlib import Path
import pytest
import ray
import yaml
from ray.rllib.algorithms import ppo
@@ -20,9 +19,6 @@ def test_rllib_single_agent_compatibility():
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = {
"env": PrimaiteRayEnv,
@@ -41,4 +37,3 @@ def test_rllib_single_agent_compatibility():
assert save_file.exists()
save_file.unlink() # clean up
ray.shutdown()

View File

@@ -20,7 +20,7 @@ def test_sb3_compatibility():
gym = PrimaiteGymEnv(env_config=cfg)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)
model.learn(total_timesteps=256)
save_path = Path(tempfile.gettempdir()) / "model.zip"
model.save(save_path)

View File

@@ -65,25 +65,25 @@ class TestPrimaiteEnvironment:
cfg = yaml.safe_load(f)
env = PrimaiteRayMARLEnv(env_config=cfg)
assert set(env._agent_ids) == {"defender1", "defender2"}
assert set(env._agent_ids) == {"defender_1", "defender_2"}
assert len(env.agents) == 2
defender1 = env.agents["defender1"]
defender2 = env.agents["defender2"]
assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54
assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38
defender_1 = env.agents["defender_1"]
defender_2 = env.agents["defender_2"]
assert (num_actions_1 := len(defender_1.action_manager.action_map)) == 78
assert (num_actions_2 := len(defender_2.action_manager.action_map)) == 78
# ensure we can run all valid actions without error
for act_1 in range(num_actions_1):
env.step({"defender1": act_1, "defender2": 0})
env.step({"defender_1": act_1, "defender_2": 0})
for act_2 in range(num_actions_2):
env.step({"defender1": 0, "defender2": act_2})
env.step({"defender_1": 0, "defender_2": act_2})
# ensure we get error when taking an invalid action
with pytest.raises(KeyError):
env.step({"defender1": num_actions_1, "defender2": 0})
env.step({"defender_1": num_actions_1, "defender_2": 0})
with pytest.raises(KeyError):
env.step({"defender1": 0, "defender2": num_actions_2})
env.step({"defender_1": 0, "defender_2": num_actions_2})
def test_error_thrown_on_bad_configuration(self):
"""Make sure we throw an error when the config is bad."""

View File

@@ -67,7 +67,7 @@ def test_dev_mode_config_sys_log_level():
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "DEBUG" # DEBUG by default
result = cli(["dev-mode", "config", "-level", "WARNING"])
result = cli(["dev-mode", "config", "--sys-log-level", "WARNING"])
assert "sys_log_level=WARNING" in result.output # should print correct value
@@ -78,10 +78,30 @@ def test_dev_mode_config_sys_log_level():
assert "sys_log_level=INFO" in result.output # should print correct value
# config should reflect that log level is WARNING
# config should reflect that log level is INFO
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "INFO"
def test_dev_mode_config_agent_log_level():
"""Check that the agent log level can be changed via CLI."""
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] == "DEBUG" # DEBUG by default
result = cli(["dev-mode", "config", "-level", "WARNING"])
assert "agent_log_level=WARNING" in result.output # should print correct value
# config should reflect that log level is WARNING
assert PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] == "WARNING"
result = cli(["dev-mode", "config", "--agent-log-level", "INFO"])
assert "agent_log_level=INFO" in result.output # should print correct value
# config should reflect that log level is INFO
assert PRIMAITE_CONFIG["developer_mode"]["agent_log_level"] == "INFO"
def test_dev_mode_config_sys_logs_enable_disable():
"""Test that the system logs output can be enabled or disabled."""
# check defaults
@@ -112,6 +132,36 @@ def test_dev_mode_config_sys_logs_enable_disable():
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False
def test_dev_mode_config_agent_logs_enable_disable():
"""Test that the agent logs output can be enabled or disabled."""
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] is False # False by default
result = cli(["dev-mode", "config", "--output-agent-logs"])
assert "output_agent_logs=True" in result.output # should print correct value
# config should reflect that output_agent_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"]
result = cli(["dev-mode", "config", "--no-agent-logs"])
assert "output_agent_logs=False" in result.output # should print correct value
# config should reflect that output_agent_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] is False
result = cli(["dev-mode", "config", "-agent"])
assert "output_agent_logs=True" in result.output # should print correct value
# config should reflect that output_agent_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"]
result = cli(["dev-mode", "config", "-nagent"])
assert "output_agent_logs=False" in result.output # should print correct value
# config should reflect that output_agent_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_agent_logs"] is False
def test_dev_mode_config_pcap_logs_enable_disable():
"""Test that the pcap logs output can be enabled or disabled."""
# check defaults

View File

@@ -35,3 +35,7 @@ def test_io_settings():
assert env.io.settings.save_step_metadata is False
assert env.io.settings.write_sys_log_to_terminal is False # false by default
assert env.io.settings.save_agent_logs is True
assert env.io.settings.agent_log_level is LogLevel.INFO
assert env.io.settings.write_agent_log_to_terminal is True # Set to True by the config file.

View File

@@ -1,35 +1,23 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import copy
from ipaddress import IPv4Address
from pathlib import Path
from typing import Union
import yaml
from primaite.config.load import data_manipulation_config_path
from primaite.game.agent.interface import ProxyAgent
from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent
from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.container import Network
from primaite.game.game import PrimaiteGame, SERVICE_TYPES_MAPPING
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.simulator.system.services.ntp.ntp_server import NTPServer
from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT
TEST_CONFIG = TEST_ASSETS_ROOT / "configs/software_fix_duration.yaml"
ONE_ITEM_CONFIG = TEST_ASSETS_ROOT / "configs/fix_duration_one_item.yaml"
TestApplications = ["DummyApplication", "BroadcastTestClient"]
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
@@ -62,9 +50,12 @@ def test_fix_duration_set_from_config():
assert client_1.software_manager.software.get(service).fixing_duration == 3
# in config - applications take 1 timestep to fix
for applications in APPLICATION_TYPES_MAPPING:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 1
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
for application in applications:
assert client_1.software_manager.software.get(application) is not None
assert client_1.software_manager.software.get(application).fixing_duration == 1
def test_fix_duration_for_one_item():
@@ -80,8 +71,9 @@ def test_fix_duration_for_one_item():
assert client_1.software_manager.software.get(service).fixing_duration == 2
# in config - applications take 1 timestep to fix
applications = copy.copy(APPLICATION_TYPES_MAPPING)
applications.pop("DatabaseClient")
# remove test applications from list
applications = set(Application._application_registry) - set(TestApplications)
applications.remove("DatabaseClient")
for applications in applications:
assert client_1.software_manager.software.get(applications) is not None
assert client_1.software_manager.software.get(applications).fixing_duration == 2

View File

@@ -0,0 +1,54 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return (game, agent)
def test_application_cannot_perform_actions_unless_running(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test the the request permissions prevent any actions unless application is running."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
browser: WebBrowser = client_1.software_manager.software.get("WebBrowser")
browser.close()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_SCAN", {"node_id": 0, "application_id": 0})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_CLOSE", {"node_id": 0, "application_id": 0})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_FIX", {"node_id": 0, "application_id": 0})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED
action = ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0})
agent.store_action(action)
game.step()
assert browser.operating_state == ApplicationOperatingState.CLOSED

View File

@@ -99,7 +99,7 @@ class TestConfigureDatabaseAction:
game.step()
assert db_client.server_ip_address == old_ip
assert db_client.server_password is "admin123"
assert db_client.server_password == "admin123"
class TestConfigureRansomwareScriptAction:

View File

@@ -0,0 +1,159 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import uuid
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return (game, agent)
def test_create_file(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a files to be created."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
random_folder = str(uuid.uuid4())
random_file = str(uuid.uuid4())
assert client_1.file_system.get_file(folder_name=random_folder, file_name=random_file) is None
action = (
"NODE_FILE_CREATE",
{"node_id": 0, "folder_name": random_folder, "file_name": random_file},
)
agent.store_action(action)
game.step()
assert client_1.file_system.get_file(folder_name=random_folder, file_name=random_file) is not None
def test_file_delete_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a file to be deleted."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png")
assert file.deleted is False
action = (
"NODE_FILE_DELETE",
{"node_id": 0, "folder_id": 0, "file_id": 0},
)
agent.store_action(action)
game.step()
assert file.deleted
def test_file_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a file to be scanned."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png")
file.corrupt()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
action = (
"NODE_FILE_SCAN",
{"node_id": 0, "folder_id": 0, "file_id": 0},
)
agent.store_action(action)
game.step()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT
def test_file_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a folder to be created."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png")
file.corrupt()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FILE_REPAIR",
{"node_id": 0, "folder_id": 0, "file_id": 0},
)
agent.store_action(action)
game.step()
assert file.health_status == FileSystemItemHealthStatus.GOOD
def test_file_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a file to be restored."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png")
file.corrupt()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FILE_RESTORE",
{"node_id": 0, "folder_id": 0, "file_id": 0},
)
agent.store_action(action)
game.step()
assert file.health_status == FileSystemItemHealthStatus.GOOD
def test_file_corrupt_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a file to be corrupted."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png")
assert file.health_status == FileSystemItemHealthStatus.GOOD
action = (
"NODE_FILE_CORRUPT",
{"node_id": 0, "folder_id": 0, "file_id": 0},
)
agent.store_action(action)
game.step()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
def test_file_access_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a file to be accessed."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
file = client_1.file_system.get_file(folder_name="downloads", file_name="cat.png")
assert file.num_access == 0
action = (
"NODE_FILE_ACCESS",
{"node_id": 0, "folder_name": file.folder_name, "file_name": file.name},
)
agent.store_action(action)
game.step()
assert file.num_access == 1

View File

@@ -0,0 +1,123 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import uuid
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return (game, agent)
def test_create_folder(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator allows a folder to be created."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
random_folder = str(uuid.uuid4())
assert client_1.file_system.get_folder(folder_name=random_folder) is None
action = (
"NODE_FOLDER_CREATE",
{
"node_id": 0,
"folder_name": random_folder,
},
)
agent.store_action(action)
game.step()
assert client_1.file_system.get_folder(folder_name=random_folder) is not None
def test_folder_scan_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test to make sure that the validator checks if the folder exists before scanning."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
folder = client_1.file_system.get_folder(folder_name="downloads")
assert folder.health_status == FileSystemItemHealthStatus.GOOD
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
folder.corrupt()
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
assert folder.visible_health_status == FileSystemItemHealthStatus.GOOD
action = (
"NODE_FOLDER_SCAN",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads
},
)
agent.store_action(action)
game.step()
for i in range(folder.scan_duration + 1):
game.step()
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
assert folder.visible_health_status == FileSystemItemHealthStatus.CORRUPT
def test_folder_repair_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test to make sure that the validator checks if the folder exists before repairing."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
folder = client_1.file_system.get_folder(folder_name="downloads")
folder.corrupt()
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FOLDER_REPAIR",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads
},
)
agent.store_action(action)
game.step()
assert folder.health_status == FileSystemItemHealthStatus.GOOD
def test_folder_restore_action(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test to make sure that the validator checks if the folder exists before restoring."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
folder = client_1.file_system.get_folder(folder_name="downloads")
folder.corrupt()
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
action = (
"NODE_FOLDER_RESTORE",
{
"node_id": 0, # client_1,
"folder_id": 0, # downloads
},
)
agent.store_action(action)
game.step()
assert folder.health_status == FileSystemItemHealthStatus.RESTORING

View File

@@ -0,0 +1,95 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return (game, agent)
def test_nic_cannot_be_turned_off_if_not_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that a NIC cannot be disabled if it is not enabled."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
nic = client_1.network_interface[1]
nic.disable()
assert nic.enabled is False
action = (
"HOST_NIC_DISABLE",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
},
)
agent.store_action(action)
game.step()
assert nic.enabled is False
def test_nic_cannot_be_turned_on_if_already_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that a NIC cannot be enabled if it is already enabled."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
nic = client_1.network_interface[1]
assert nic.enabled
action = (
"HOST_NIC_ENABLE",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
},
)
agent.store_action(action)
game.step()
assert nic.enabled
def test_that_a_nic_can_be_enabled_and_disabled(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Tests that a NIC can be enabled and disabled."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
nic = client_1.network_interface[1]
assert nic.enabled
action = (
"HOST_NIC_DISABLE",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
},
)
agent.store_action(action)
game.step()
assert nic.enabled is False
action = (
"HOST_NIC_ENABLE",
{
"node_id": 0, # client_1
"nic_id": 0, # the only nic (eth-1)
},
)
agent.store_action(action)
game.step()
assert nic.enabled

View File

@@ -0,0 +1,94 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return (game, agent)
def test_node_startup_shutdown(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the node can be shut down and started up."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
assert client_1.operating_state == NodeOperatingState.ON
# turn it off
action = ("NODE_SHUTDOWN", {"node_id": 0})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.SHUTTING_DOWN
for i in range(client_1.shut_down_duration + 1):
action = ("DONOTHING", {"node_id": 0})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.OFF
# turn it on
action = ("NODE_STARTUP", {"node_id": 0})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.BOOTING
for i in range(client_1.start_up_duration + 1):
action = ("DONOTHING", {"node_id": 0})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.ON
def test_node_cannot_be_started_up_if_node_is_already_on(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that a node cannot be started up if it is already on."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
assert client_1.operating_state == NodeOperatingState.ON
# turn it on
action = ("NODE_STARTUP", {"node_id": 0})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.ON
def test_node_cannot_be_shut_down_if_node_is_already_off(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that a node cannot be shut down if it is already off."""
game, agent = game_and_agent_fixture
client_1 = game.simulation.network.get_node_by_hostname("client_1")
client_1.power_off()
for i in range(client_1.shut_down_duration + 1):
action = ("DONOTHING", {"node_id": 0})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.OFF
# turn it ff
action = ("NODE_SHUTDOWN", {"node_id": 0})
agent.store_action(action)
game.step()
assert client_1.operating_state == NodeOperatingState.OFF

View File

@@ -0,0 +1,106 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from typing import Tuple
import pytest
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.hardware.nodes.host.computer import Computer
from primaite.simulator.network.hardware.nodes.host.server import Server
from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture
def game_and_agent_fixture(game_and_agent):
"""Create a game with a simple agent that can be controlled by the tests."""
game, agent = game_and_agent
client_1: Computer = game.simulation.network.get_node_by_hostname("client_1")
client_1.start_up_duration = 3
return (game, agent)
def test_service_start(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator makes sure that the service is stopped before starting the service."""
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
dns_server = server_1.software_manager.software.get("DNSServer")
dns_server.pause()
assert dns_server.operating_state == ServiceOperatingState.PAUSED
action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.PAUSED
dns_server.stop()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_START", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.RUNNING
def test_service_resume(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test that the validator checks if the service is paused before resuming."""
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
dns_server = server_1.software_manager.software.get("DNSServer")
action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.RUNNING
dns_server.pause()
assert dns_server.operating_state == ServiceOperatingState.PAUSED
action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.RUNNING
def test_service_cannot_perform_actions_unless_running(game_and_agent_fixture: Tuple[PrimaiteGame, ProxyAgent]):
"""Test to make sure that the service cannot perform certain actions while not running."""
game, agent = game_and_agent_fixture
server_1: Server = game.simulation.network.get_node_by_hostname("server_1")
dns_server = server_1.software_manager.software.get("DNSServer")
dns_server.stop()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_SCAN", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_PAUSE", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_RESUME", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_RESTART", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED
action = ("NODE_SERVICE_FIX", {"node_id": 1, "service_id": 0})
agent.store_action(action)
game.step()
assert dns_server.operating_state == ServiceOperatingState.STOPPED

View File

@@ -155,7 +155,7 @@ def test_nic_monitored_traffic(simulation):
assert traffic_obs["icmp"]["outbound"] == 0
# send a ping
pc.ping(target_ip_address=pc2.network_interface[1].ip_address)
assert pc.ping(target_ip_address=pc2.network_interface[1].ip_address)
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
assert traffic_obs["icmp"]["inbound"] == 1
@@ -178,7 +178,7 @@ def test_nic_monitored_traffic(simulation):
traffic_obs = nic_obs.observe(simulation.describe_state()).get("TRAFFIC")
assert traffic_obs["icmp"]["inbound"] == 0
assert traffic_obs["icmp"]["outbound"] == 0
assert traffic_obs["tcp"][53]["inbound"] == 0
assert traffic_obs["tcp"][53]["inbound"] == 1
assert traffic_obs["tcp"][53]["outbound"] == 1 # getting a webpage sent a dns request out
simulation.pre_timestep(2) # apply timestep to whole sim

View File

@@ -0,0 +1,161 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.session.environment import PrimaiteGymEnv
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.service import ServiceOperatingState
from tests.conftest import TEST_ASSETS_ROOT
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
def test_mask_contents_correct():
env = PrimaiteGymEnv(CFG_PATH)
game = env.game
sim = game.simulation
net = sim.network
mask = game.action_mask("defender")
agent = env.agent
node_list = agent.action_manager.node_names
action_map = agent.action_manager.action_map
# CHECK NIC ENABLE/DISABLE ACTIONS
for action_num, action in action_map.items():
mask = game.action_mask("defender")
act_type, act_params = action
if act_type == "NODE_NIC_ENABLE":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
assert nic_obj.enabled
assert not mask[action_num]
nic_obj.disable()
mask = game.action_mask("defender")
assert mask[action_num]
nic_obj.enable()
if act_type == "NODE_NIC_DISABLE":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
nic_obj = node_obj.network_interface[act_params["nic_id"] + 1]
assert nic_obj.enabled
assert mask[action_num]
nic_obj.disable()
mask = game.action_mask("defender")
assert not mask[action_num]
nic_obj.enable()
if act_type == "ROUTER_ACL_ADDRULE":
assert mask[action_num]
if act_type == "ROUTER_ACL_REMOVERULE":
assert mask[action_num]
if act_type == "NODE_RESET":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_SHUTDOWN":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_OS_SCAN":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "NODE_STARTUP":
node_name = node_list[act_params["node_id"]]
node_obj = net.get_node_by_hostname(node_name)
assert node_obj.operating_state is NodeOperatingState.ON
assert not mask[action_num]
node_obj.operating_state = NodeOperatingState.OFF
mask = game.action_mask("defender")
assert mask[action_num]
node_obj.operating_state = NodeOperatingState.ON
if act_type == "DONOTHING":
assert mask[action_num]
if act_type == "NODE_SERVICE_DISABLE":
assert mask[action_num]
if act_type in ["NODE_SERVICE_SCAN", "NODE_SERVICE_STOP", "NODE_SERVICE_PAUSE"]:
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.DISABLED
mask = game.action_mask("defender")
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_RESUME":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.PAUSED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_START":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.STOPPED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type == "NODE_SERVICE_ENABLE":
node_name = node_list[act_params["node_id"]]
service_name = agent.action_manager.service_names[act_params["node_id"]][act_params["service_id"]]
node_obj = net.get_node_by_hostname(node_name)
service_obj = node_obj.software_manager.software.get(service_name)
assert service_obj.operating_state is ServiceOperatingState.RUNNING
assert not mask[action_num]
service_obj.operating_state = ServiceOperatingState.DISABLED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING
if act_type in ["NODE_FILE_SCAN", "NODE_FILE_CHECKHASH", "NODE_FILE_DELETE"]:
node_name = node_list[act_params["node_id"]]
folder_name = agent.action_manager.get_folder_name_by_idx(act_params["node_id"], act_params["folder_id"])
file_name = agent.action_manager.get_file_name_by_idx(
act_params["node_id"], act_params["folder_id"], act_params["file_id"]
)
node_obj = net.get_node_by_hostname(node_name)
file_obj = node_obj.file_system.get_file(folder_name, file_name, include_deleted=True)
assert not file_obj.deleted
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.DISABLED
mask = game.action_mask("defender")
assert mask[action_num]
service_obj.operating_state = ServiceOperatingState.RUNNING

View File

@@ -0,0 +1,44 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import yaml
from primaite.game.game import PrimaiteGame
from primaite.simulator.network.airspace import AirSpaceFrequency
from tests import TEST_ASSETS_ROOT
def test_override_freq_max_capacity_mbps():
config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_network_config_freq_max_override.yaml"
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
network = PrimaiteGame.from_config(cfg=config_dict).simulation.network
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 123.45
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0
pc_a = network.get_node_by_hostname("pc_a")
pc_b = network.get_node_by_hostname("pc_b")
assert pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should be able to ping PC B"
assert pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should be able to ping PC A"
network.airspace.show()
def test_override_freq_max_capacity_mbps_blocked():
config_path = TEST_ASSETS_ROOT / "configs" / "wireless_wan_network_config_freq_max_override_blocked.yaml"
with open(config_path, "r") as f:
config_dict = yaml.safe_load(f)
network = PrimaiteGame.from_config(cfg=config_dict).simulation.network
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_2_4) == 0.0
assert network.airspace.get_frequency_max_capacity_mbps(AirSpaceFrequency.WIFI_5) == 0.0
pc_a = network.get_node_by_hostname("pc_a")
pc_b = network.get_node_by_hostname("pc_b")
assert not pc_a.ping(pc_b.network_interface[1].ip_address), "PC A should not be able to ping PC B"
assert not pc_b.ping(pc_a.network_interface[1].ip_address), "PC B should not be able to ping PC A"
network.airspace.show()

View File

@@ -0,0 +1,114 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.simulator.file_system.file_type import FileType
from primaite.simulator.network.hardware.nodes.network.router import ACLAction
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from tests.integration_tests.network.test_wireless_router import wireless_wan_network
from tests.integration_tests.system.test_ftp_client_server import ftp_client_and_ftp_server
def test_wireless_link_loading(wireless_wan_network):
client, server, router_1, router_2 = wireless_wan_network
# Configure Router 1 ACLs
router_1.acl.add_rule(action=ACLAction.PERMIT, position=1)
# Configure Router 2 ACLs
router_2.acl.add_rule(action=ACLAction.PERMIT, position=1)
airspace = router_1.airspace
client.software_manager.install(FTPClient)
ftp_client: FTPClient = client.software_manager.software.get("FTPClient")
ftp_client.start()
server.software_manager.install(FTPServer)
ftp_server: FTPServer = server.software_manager.software.get("FTPServer")
ftp_server.start()
client.file_system.create_file(file_name="mixtape", size=10 * 10**6, file_type=FileType.MP3, folder_name="music")
assert ftp_client.send_file(
src_file_name="mixtape.mp3",
src_folder_name="music",
dest_ip_address=server.network_interface[1].ip_address,
dest_file_name="mixtape.mp3",
dest_folder_name="music",
)
# Reset the physical links between the host nodes and the routers
client.network_interface[1]._connected_link.pre_timestep(1)
server.network_interface[1]._connected_link.pre_timestep(1)
assert not ftp_client.send_file(
src_file_name="mixtape.mp3",
src_folder_name="music",
dest_ip_address=server.network_interface[1].ip_address,
dest_file_name="mixtape3.mp3",
dest_folder_name="music",
)
# Reset the physical links between the host nodes and the routers
client.network_interface[1]._connected_link.pre_timestep(1)
server.network_interface[1]._connected_link.pre_timestep(1)
airspace.reset_bandwidth_load()
assert ftp_client.send_file(
src_file_name="mixtape.mp3",
src_folder_name="music",
dest_ip_address=server.network_interface[1].ip_address,
dest_file_name="mixtape3.mp3",
dest_folder_name="music",
)
def test_wired_link_loading(ftp_client_and_ftp_server):
ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
link = computer.network_interface[1]._connected_link # noqa
assert link.is_up
link.pre_timestep(1)
computer.file_system.create_file(
file_name="mixtape", size=10 * 10**6, file_type=FileType.MP3, folder_name="music"
)
link_load = link.current_load
assert link_load == 0.0
assert ftp_client.send_file(
src_file_name="mixtape.mp3",
src_folder_name="music",
dest_ip_address=server.network_interface[1].ip_address,
dest_file_name="mixtape.mp3",
dest_folder_name="music",
)
new_link_load = link.current_load
assert new_link_load > link_load
assert not ftp_client.send_file(
src_file_name="mixtape.mp3",
src_folder_name="music",
dest_ip_address=server.network_interface[1].ip_address,
dest_file_name="mixtape1.mp3",
dest_folder_name="music",
)
link.pre_timestep(2)
link_load = link.current_load
assert link_load == 0.0
assert ftp_client.send_file(
src_file_name="mixtape.mp3",
src_folder_name="music",
dest_ip_address=server.network_interface[1].ip_address,
dest_file_name="mixtape1.mp3",
dest_folder_name="music",
)
new_link_load = link.current_load
assert new_link_load > link_load

View File

@@ -14,7 +14,7 @@ from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
class BroadcastService(Service):
class BroadcastTestService(Service):
"""A service for sending broadcast and unicast messages over a network."""
def __init__(self, **kwargs):
@@ -41,14 +41,14 @@ class BroadcastService(Service):
super().send(payload="broadcast", dest_ip_address=ip_network, dest_port=Port.HTTP, ip_protocol=self.protocol)
class BroadcastClient(Application, identifier="BroadcastClient"):
class BroadcastTestClient(Application, identifier="BroadcastTestClient"):
"""A client application to receive broadcast and unicast messages."""
payloads_received: List = []
def __init__(self, **kwargs):
# Set default client properties
kwargs["name"] = "BroadcastClient"
kwargs["name"] = "BroadcastTestClient"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
@@ -75,8 +75,8 @@ def broadcast_network() -> Network:
start_up_duration=0,
)
client_1.power_on()
client_1.software_manager.install(BroadcastClient)
application_1 = client_1.software_manager.software["BroadcastClient"]
client_1.software_manager.install(BroadcastTestClient)
application_1 = client_1.software_manager.software["BroadcastTestClient"]
application_1.run()
client_2 = Computer(
@@ -87,8 +87,8 @@ def broadcast_network() -> Network:
start_up_duration=0,
)
client_2.power_on()
client_2.software_manager.install(BroadcastClient)
application_2 = client_2.software_manager.software["BroadcastClient"]
client_2.software_manager.install(BroadcastTestClient)
application_2 = client_2.software_manager.software["BroadcastTestClient"]
application_2.run()
server_1 = Server(
@@ -100,8 +100,8 @@ def broadcast_network() -> Network:
)
server_1.power_on()
server_1.software_manager.install(BroadcastService)
service: BroadcastService = server_1.software_manager.software["BroadcastService"]
server_1.software_manager.install(BroadcastTestService)
service: BroadcastTestService = server_1.software_manager.software["BroadcastService"]
service.start()
switch_1 = Switch(hostname="switch_1", num_ports=6, start_up_duration=0)
@@ -115,14 +115,16 @@ def broadcast_network() -> Network:
@pytest.fixture(scope="function")
def broadcast_service_and_clients(broadcast_network) -> Tuple[BroadcastService, BroadcastClient, BroadcastClient]:
client_1: BroadcastClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
"BroadcastClient"
def broadcast_service_and_clients(
broadcast_network,
) -> Tuple[BroadcastTestService, BroadcastTestClient, BroadcastTestClient]:
client_1: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_1").software_manager.software[
"BroadcastTestClient"
]
client_2: BroadcastClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
"BroadcastClient"
client_2: BroadcastTestClient = broadcast_network.get_node_by_hostname("client_2").software_manager.software[
"BroadcastTestClient"
]
service: BroadcastService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
service: BroadcastTestService = broadcast_network.get_node_by_hostname("server_1").software_manager.software[
"BroadcastService"
]

View File

@@ -101,6 +101,7 @@ def test_port_scan_full_subnet_all_ports_and_protocols(example_network):
actual_result = client_1_nmap.port_scan(
target_ip_address=IPv4Network("192.168.10.0/24"),
target_port=[Port.ARP, Port.HTTP, Port.FTP, Port.DNS, Port.NTP],
)
expected_result = {

View File

@@ -0,0 +1,137 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from uuid import uuid4
import pytest
from primaite import PRIMAITE_CONFIG
from primaite.game.agent.agent_log import AgentLog
from primaite.simulator import LogLevel, SIM_OUTPUT
@pytest.fixture(autouse=True)
def override_dev_mode_temporarily():
"""Temporarily turn off dev mode for this test."""
primaite_dev_mode = PRIMAITE_CONFIG["developer_mode"]["enabled"]
PRIMAITE_CONFIG["developer_mode"]["enabled"] = False
yield # run tests
PRIMAITE_CONFIG["developer_mode"]["enabled"] = primaite_dev_mode
@pytest.fixture(scope="function")
def agentlog() -> AgentLog:
return AgentLog(agent_name="test_agent")
def test_debug_agent_log_level(agentlog, capsys):
"""Test that the debug log level logs debug agent logs and above."""
SIM_OUTPUT.agent_log_level = LogLevel.DEBUG
SIM_OUTPUT.write_agent_log_to_terminal = True
test_string = str(uuid4())
agentlog.debug(msg=test_string)
agentlog.info(msg=test_string)
agentlog.warning(msg=test_string)
agentlog.error(msg=test_string)
agentlog.critical(msg=test_string)
captured = "".join(capsys.readouterr())
assert test_string in captured
assert "DEBUG" in captured
assert "INFO" in captured
assert "WARNING" in captured
assert "ERROR" in captured
assert "CRITICAL" in captured
def test_info_agent_log_level(agentlog, capsys):
"""Test that the debug log level logs debug agent logs and above."""
SIM_OUTPUT.agent_log_level = LogLevel.INFO
SIM_OUTPUT.write_agent_log_to_terminal = True
test_string = str(uuid4())
agentlog.debug(msg=test_string)
agentlog.info(msg=test_string)
agentlog.warning(msg=test_string)
agentlog.error(msg=test_string)
agentlog.critical(msg=test_string)
captured = "".join(capsys.readouterr())
assert test_string in captured
assert "DEBUG" not in captured
assert "INFO" in captured
assert "WARNING" in captured
assert "ERROR" in captured
assert "CRITICAL" in captured
def test_warning_agent_log_level(agentlog, capsys):
"""Test that the debug log level logs debug agent logs and above."""
SIM_OUTPUT.agent_log_level = LogLevel.WARNING
SIM_OUTPUT.write_agent_log_to_terminal = True
test_string = str(uuid4())
agentlog.debug(msg=test_string)
agentlog.info(msg=test_string)
agentlog.warning(msg=test_string)
agentlog.error(msg=test_string)
agentlog.critical(msg=test_string)
captured = "".join(capsys.readouterr())
assert test_string in captured
assert "DEBUG" not in captured
assert "INFO" not in captured
assert "WARNING" in captured
assert "ERROR" in captured
assert "CRITICAL" in captured
def test_error_agent_log_level(agentlog, capsys):
"""Test that the debug log level logs debug agent logs and above."""
SIM_OUTPUT.agent_log_level = LogLevel.ERROR
SIM_OUTPUT.write_agent_log_to_terminal = True
test_string = str(uuid4())
agentlog.debug(msg=test_string)
agentlog.info(msg=test_string)
agentlog.warning(msg=test_string)
agentlog.error(msg=test_string)
agentlog.critical(msg=test_string)
captured = "".join(capsys.readouterr())
assert test_string in captured
assert "DEBUG" not in captured
assert "INFO" not in captured
assert "WARNING" not in captured
assert "ERROR" in captured
assert "CRITICAL" in captured
def test_critical_agent_log_level(agentlog, capsys):
"""Test that the debug log level logs debug agent logs and above."""
SIM_OUTPUT.agent_log_level = LogLevel.CRITICAL
SIM_OUTPUT.write_agent_log_to_terminal = True
test_string = str(uuid4())
agentlog.debug(msg=test_string)
agentlog.info(msg=test_string)
agentlog.warning(msg=test_string)
agentlog.error(msg=test_string)
agentlog.critical(msg=test_string)
captured = "".join(capsys.readouterr())
assert test_string in captured
assert "DEBUG" not in captured
assert "INFO" not in captured
assert "WARNING" not in captured
assert "ERROR" not in captured
assert "CRITICAL" in captured

View File

@@ -26,7 +26,7 @@ def test_file_scan_request(populated_file_system):
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
assert file.visible_health_status == FileSystemItemHealthStatus.GOOD
fs.apply_request(request=["file", file.name, "scan"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "scan"])
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
assert file.visible_health_status == FileSystemItemHealthStatus.CORRUPT
@@ -37,12 +37,12 @@ def test_file_checkhash_request(populated_file_system):
"""Test that an agent can request a file hash check."""
fs, folder, file = populated_file_system
fs.apply_request(request=["file", file.name, "checkhash"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "checkhash"])
assert file.health_status == FileSystemItemHealthStatus.GOOD
file.sim_size = 0
fs.apply_request(request=["file", file.name, "checkhash"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "checkhash"])
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
@@ -54,7 +54,7 @@ def test_file_repair_request(populated_file_system):
file.corrupt()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
fs.apply_request(request=["file", file.name, "repair"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "repair"])
assert file.health_status == FileSystemItemHealthStatus.GOOD
@@ -71,7 +71,7 @@ def test_file_restore_request(populated_file_system):
assert fs.get_file(folder_name=folder.name, file_name=file.name) is not None
assert fs.get_file(folder_name=folder.name, file_name=file.name).deleted is False
fs.apply_request(request=["file", file.name, "corrupt"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"])
assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT
fs.apply_request(request=["restore", "file", folder.name, file.name])
@@ -81,7 +81,7 @@ def test_file_restore_request(populated_file_system):
def test_file_corrupt_request(populated_file_system):
"""Test that an agent can request a file corruption."""
fs, folder, file = populated_file_system
fs.apply_request(request=["file", file.name, "corrupt"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"])
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
@@ -90,7 +90,7 @@ def test_deleted_file_cannot_be_interacted_with(populated_file_system):
fs, folder, file = populated_file_system
assert fs.get_file(folder_name=folder.name, file_name=file.name) is not None
fs.apply_request(request=["file", file.name, "corrupt"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"])
assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT
assert (
fs.get_file(folder_name=folder.name, file_name=file.name).visible_health_status

View File

@@ -39,3 +39,39 @@ def test_folder_delete_request(populated_file_system):
assert fs.get_file_by_id(folder_uuid=folder.uuid, file_uuid=file.uuid) is None
fs.show(full=True)
def test_folder_exists_request_validator(populated_file_system):
"""Tests that the _FolderExistsValidator works as intended."""
fs, folder, file = populated_file_system
validator = FileSystem._FolderExistsValidator(file_system=fs)
assert validator(request=["test_folder"], context={}) # test_folder exists
assert validator(request=["fake_folder"], context={}) is False # fake_folder does not exist
assert validator.fail_message == "Cannot perform request on folder because it does not exist."
def test_file_exists_request_validator(populated_file_system):
"""Tests that the _FolderExistsValidator works as intended."""
fs, folder, file = populated_file_system
validator = FileSystem._FileExistsValidator(file_system=fs)
assert validator(request=["test_folder", "test_file.txt"], context={}) # test_file.txt exists
assert validator(request=["test_folder", "fake_file.txt"], context={}) is False # fake_file.txt does not exist
assert validator.fail_message == "Cannot perform request on a file that does not exist."
def test_folder_not_deleted_request_validator(populated_file_system):
"""Tests that the _FolderExistsValidator works as intended."""
fs, folder, file = populated_file_system
validator = FileSystem._FolderNotDeletedValidator(file_system=fs)
assert validator(request=["test_folder"], context={}) # test_folder is not deleted
fs.delete_folder(folder_name="test_folder")
assert validator(request=["test_folder"], context={}) is False # test_folder is deleted
assert validator.fail_message == "Cannot perform request on folder because it is deleted."

View File

@@ -166,15 +166,40 @@ def test_deleted_folder_and_its_files_cannot_be_interacted_with(populated_file_s
fs, folder, file = populated_file_system
assert fs.get_file(folder_name=folder.name, file_name=file.name) is not None
fs.apply_request(request=["file", file.name, "corrupt"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "corrupt"])
assert fs.get_file(folder_name=folder.name, file_name=file.name).health_status == FileSystemItemHealthStatus.CORRUPT
fs.apply_request(request=["delete", "folder", folder.name])
assert fs.get_file(folder_name=folder.name, file_name=file.name) is None
fs.apply_request(request=["file", file.name, "repair"])
fs.apply_request(request=["folder", folder.name, "file", file.name, "repair"])
deleted_folder = fs.deleted_folders.get(folder.uuid)
deleted_file = deleted_folder.deleted_files.get(file.uuid)
assert deleted_file.health_status is not FileSystemItemHealthStatus.GOOD
def test_file_exists_request_validator(populated_file_system):
"""Tests that the _FolderExistsValidator works as intended."""
fs, folder, file = populated_file_system
validator = Folder._FileExistsValidator(folder=folder)
assert validator(request=["test_file.txt"], context={}) # test_file.txt exists
assert validator(request=["fake_file.txt"], context={}) is False # fake_file.txt does not exist
assert validator.fail_message == "Cannot perform request on a file that does not exist."
def test_file_not_deleted_request_validator(populated_file_system):
"""Tests that the _FolderExistsValidator works as intended."""
fs, folder, file = populated_file_system
validator = Folder._FileNotDeletedValidator(folder=folder)
assert validator(request=["test_file.txt"], context={}) # test_file.txt is not deleted
fs.delete_file(folder_name="test_folder", file_name="test_file.txt")
assert validator(request=["fake_file.txt"], context={}) is False # test_file.txt is deleted
assert validator.fail_message == "Cannot perform request on a file that is deleted."

View File

@@ -0,0 +1,34 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
import pytest
from primaite.simulator.network.hardware.base import NetworkInterface, Node
from primaite.simulator.network.hardware.nodes.host.computer import Computer
@pytest.fixture
def node() -> Node:
return Computer(hostname="test", ip_address="192.168.1.2", subnet_mask="255.255.255.0")
def test_nic_enabled_validator(node):
"""Test the NetworkInterface enabled validator."""
network_interface = node.network_interface[1]
validator = NetworkInterface._EnabledValidator(network_interface=network_interface)
assert validator(request=[], context={}) is False # not enabled
network_interface.enabled = True
assert validator(request=[], context={}) # enabled
def test_nic_disabled_validator(node):
"""Test the NetworkInterface enabled validator."""
network_interface = node.network_interface[1]
validator = NetworkInterface._DisabledValidator(network_interface=network_interface)
assert validator(request=[], context={}) # not enabled
network_interface.enabled = True
assert validator(request=[], context={}) is False # enabled

View File

@@ -155,3 +155,39 @@ def test_reset_node(node):
assert node.operating_state == NodeOperatingState.BOOTING
assert node.operating_state == NodeOperatingState.ON
def test_node_is_on_validator(node):
"""Test that the node is on validator."""
node.power_on()
for i in range(node.start_up_duration + 1):
node.apply_timestep(i)
validator = Node._NodeIsOnValidator(node=node)
assert validator(request=[], context={})
node.power_off()
for i in range(node.shut_down_duration + 1):
node.apply_timestep(i)
assert validator(request=[], context={}) is False
def test_node_is_off_validator(node):
"""Test that the node is on validator."""
node.power_on()
for i in range(node.start_up_duration + 1):
node.apply_timestep(i)
validator = Node._NodeIsOffValidator(node=node)
assert validator(request=[], context={}) is False
node.power_off()
for i in range(node.shut_down_duration + 1):
node.apply_timestep(i)
assert validator(request=[], context={})

View File

@@ -1 +1,15 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
def test_application_state_validator(application):
"""Test the application state validator."""
validator = Application._StateValidator(application=application, state=ApplicationOperatingState.CLOSED)
assert validator(request=[], context={}) # application is closed
application.run()
assert validator(request=[], context={}) is False # application is running - expecting closed
validator = Application._StateValidator(application=application, state=ApplicationOperatingState.RUNNING)
assert validator(request=[], context={}) # application is running
application.close()
assert validator(request=[], context={}) is False # application is closed - expecting running

View File

@@ -1,5 +1,5 @@
# © Crown-owned copyright 2024, Defence Science and Technology Laboratory UK
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
@@ -92,3 +92,21 @@ def test_service_fix(service):
assert service.health_state_actual == SoftwareHealthState.FIXING
service.apply_timestep(2)
assert service.health_state_actual == SoftwareHealthState.GOOD
def test_service_state_validator(service):
"""Test the service state validator."""
validator = Service._StateValidator(service=service, state=ServiceOperatingState.STOPPED)
assert validator(request=[], context={}) # service is stopped
service.start()
assert validator(request=[], context={}) is False # service is running - expecting stopped
validator = Service._StateValidator(service=service, state=ServiceOperatingState.RUNNING)
assert validator(request=[], context={}) # service is running
service.pause()
assert validator(request=[], context={}) is False # service is paused - expecting running
validator = Service._StateValidator(service=service, state=ServiceOperatingState.PAUSED)
assert validator(request=[], context={}) # service is paused
service.resume()
assert validator(request=[], context={}) is False # service is running - expecting paused