Merge branch 'dev' into bugfix/2299-check_hash_function_corrupts_files_and_folders
This commit is contained in:
@@ -3,6 +3,7 @@ repos:
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
exclude: scenario_with_placeholders/
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- id: check-added-large-files
|
||||
|
||||
11
CHANGELOG.md
11
CHANGELOG.md
@@ -7,9 +7,20 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## 3.0.0b9
|
||||
- Removed deprecated `PrimaiteSession` class.
|
||||
- Added ability to set log levels via configuration.
|
||||
- Upgraded pydantic to version 2.7.0
|
||||
- Upgraded Ray to version >= 2.9
|
||||
- Added ipywidgets to the dependencies
|
||||
- Added ability to define scenarios that change depending on the episode number.
|
||||
- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config`
|
||||
- Database Connection ID's are now created/issued by DatabaseService and not DatabaseClient
|
||||
- added ability to set PrimAITE between development and production modes via PrimAITE CLI ``mode`` command
|
||||
- Updated DatabaseClient so that it can now have a single native DatabaseClientConnection along with a collection of DatabaseClientConnection's.
|
||||
- Implemented the uninstall functionality for DatabaseClient so that all connections are terminated at the DatabaseService.
|
||||
- Added the ability for a DatabaseService to terminate a connection.
|
||||
- Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used.
|
||||
- Added additional show functions to enable connection inspection.
|
||||
|
||||
|
||||
## [Unreleased]
|
||||
- Made requests fail to reach their target if the node is off
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
PrimAITE |VERSION| Configuration
|
||||
********************************
|
||||
|
||||
PrimAITE uses a single configuration file to define everything needed to create the training environment for RL agents, including the network, the scripted agents, and the RL agent's action space, observation space, and reward function.
|
||||
PrimAITE uses YAML configuration files to define everything needed to create the training environment for RL agents, including the network, the scripted agents, and the RL agent's action space, observation space, and reward function.
|
||||
|
||||
Example Configuration Hierarchy
|
||||
###############################
|
||||
@@ -34,3 +34,8 @@ Configurable items
|
||||
configuration/game.rst
|
||||
configuration/agents.rst
|
||||
configuration/simulation.rst
|
||||
|
||||
Varying The Configuration Each Episode
|
||||
######################################
|
||||
|
||||
PrimAITE allows for the configuration to be varied each episode. This is done by specifying a configuration folder instead of a single file. A full explanation is provided in the notebook `Using-Episode-Schedules.ipynb`. Please find the notebook in the user notebooks directory.
|
||||
|
||||
@@ -18,6 +18,8 @@ This section configures how PrimAITE saves data during simulation and training.
|
||||
save_step_metadata: False
|
||||
save_pcap_logs: False
|
||||
save_sys_logs: False
|
||||
write_sys_log_to_terminal: False
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
``save_logs``
|
||||
@@ -25,7 +27,6 @@ This section configures how PrimAITE saves data during simulation and training.
|
||||
|
||||
*currently unused*.
|
||||
|
||||
|
||||
``save_agent_actions``
|
||||
----------------------
|
||||
|
||||
@@ -55,3 +56,35 @@ If ``True``, then the pcap files which contain all network traffic during the si
|
||||
Optional. Default value is ``False``.
|
||||
|
||||
If ``True``, then the log files which contain all node actions during the simulation will be saved.
|
||||
|
||||
|
||||
``write_sys_log_to_terminal``
|
||||
-----------------------------
|
||||
|
||||
Optional. Default value is ``False``.
|
||||
|
||||
If ``True``, PrimAITE will print sys log to the terminal.
|
||||
|
||||
|
||||
``sys_log_level``
|
||||
-------------
|
||||
|
||||
Optional. Default value is ``WARNING``.
|
||||
|
||||
The level of logging that should be visible in the sys logs or the logs output to the terminal.
|
||||
|
||||
``save_sys_logs`` or ``write_sys_log_to_terminal`` has to be set to ``True`` for this setting to be used.
|
||||
|
||||
Available options are:
|
||||
|
||||
- ``DEBUG``: Debug level items and the items below
|
||||
- ``INFO``: Info level items and the items below
|
||||
- ``WARNING``: Warning level items and the items below
|
||||
- ``ERROR``: Error level items and the items below
|
||||
- ``CRITICAL``: Only critical level logs
|
||||
|
||||
See also |logging_levels|
|
||||
|
||||
.. |logging_levels| raw:: html
|
||||
|
||||
<a href="https://docs.python.org/3/library/logging.html#logging-levels" target="blank">Python logging levels</a>
|
||||
|
||||
@@ -141,3 +141,29 @@ of your choice:
|
||||
pip install -e .[dev]
|
||||
|
||||
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).
|
||||
|
||||
4. Set PrimAITE to run on development mode
|
||||
|
||||
Running step 3 should have installed PrimAITE, verify this by running
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
primaite setup
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
primaite setup
|
||||
|
||||
To set PrimAITE to run in development mode:
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
primaite mode --dev
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
primaite mode --dev
|
||||
|
||||
@@ -14,13 +14,14 @@ Key features
|
||||
|
||||
- Connects to the :ref:`DatabaseService` via the ``SoftwareManager``.
|
||||
- Handles connecting and disconnecting.
|
||||
- Handles multiple connections using a dictionary, mapped to connection UIDs
|
||||
- Executes SQL queries and retrieves result sets.
|
||||
|
||||
Usage
|
||||
=====
|
||||
|
||||
- Initialise with server IP address and optional password.
|
||||
- Connect to the :ref:`DatabaseService` with ``connect``.
|
||||
- Connect to the :ref:`DatabaseService` with ``get_new_connection``.
|
||||
- Retrieve results in a dictionary.
|
||||
- Disconnect when finished.
|
||||
|
||||
@@ -28,6 +29,7 @@ Implementation
|
||||
==============
|
||||
|
||||
- Leverages ``SoftwareManager`` for sending payloads over the network.
|
||||
- Active sessions are held as ``DatabaseClientConnection`` objects in a dictionary.
|
||||
- Connect and disconnect methods manage sessions.
|
||||
- Payloads serialised as dictionaries for transmission.
|
||||
- Extends base Application class.
|
||||
@@ -63,6 +65,9 @@ Python
|
||||
database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService
|
||||
database_client.run()
|
||||
|
||||
# Establish a new connection
|
||||
database_client.get_new_connection()
|
||||
|
||||
|
||||
Via Configuration
|
||||
"""""""""""""""""
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing_extensions import Annotated
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
app = typer.Typer()
|
||||
app = typer.Typer(no_args_is_help=True)
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -114,3 +114,47 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def mode(
|
||||
dev: Annotated[bool, typer.Option("--dev", help="Activates PrimAITE developer mode")] = None,
|
||||
prod: Annotated[bool, typer.Option("--prod", help="Activates PrimAITE production mode")] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Switch PrimAITE between developer mode and production mode.
|
||||
|
||||
By default, PrimAITE will be in production mode.
|
||||
|
||||
To view the current mode, use: primaite mode
|
||||
|
||||
To set to development mode, use: primaite mode --dev
|
||||
|
||||
To return to production mode, use: primaite mode --prod
|
||||
"""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
if dev and prod:
|
||||
print("Unable to activate developer and production modes concurrently.")
|
||||
return
|
||||
|
||||
if (dev is None) and (prod is None):
|
||||
is_dev_mode = primaite_config["developer_mode"]
|
||||
|
||||
if is_dev_mode:
|
||||
print("PrimAITE is running in developer mode.")
|
||||
else:
|
||||
print("PrimAITE is running in production mode.")
|
||||
if dev:
|
||||
# activate dev mode
|
||||
primaite_config["developer_mode"] = True
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print("PrimAITE is running in developer mode.")
|
||||
if prod:
|
||||
# activate prod mode
|
||||
primaite_config["developer_mode"] = False
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print("PrimAITE is running in production mode.")
|
||||
|
||||
@@ -2,7 +2,8 @@ io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: true
|
||||
save_sys_logs: false
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
game:
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
# No green agents present
|
||||
greens: &greens []
|
||||
@@ -0,0 +1,34 @@
|
||||
agents: &greens
|
||||
- ref: green_A
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.2
|
||||
1: 0.8
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client
|
||||
@@ -0,0 +1,34 @@
|
||||
agents: &greens
|
||||
- ref: green_B
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.95
|
||||
1: 0.05
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client
|
||||
@@ -0,0 +1,2 @@
|
||||
# No red agents present
|
||||
reds: &reds []
|
||||
@@ -0,0 +1,26 @@
|
||||
reds: &reds
|
||||
- ref: red_A
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 10
|
||||
frequency: 10
|
||||
variance: 0
|
||||
@@ -0,0 +1,26 @@
|
||||
reds: &reds
|
||||
- ref: red_B
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 3
|
||||
frequency: 2
|
||||
variance: 1
|
||||
@@ -0,0 +1,168 @@
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: false
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 128
|
||||
ports:
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- *greens
|
||||
- *reds
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
routers: []
|
||||
hosts:
|
||||
- hostname: client
|
||||
- hostname: server
|
||||
num_services: 1
|
||||
num_applications: 1
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 1
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- client:eth-1<->switch_1:eth-1
|
||||
- server:eth-1<->switch_1:eth-2
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
options:
|
||||
node_id: 0
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
options:
|
||||
node_id: 1
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 0
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 1
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
- node_name: server
|
||||
|
||||
max_folders_per_node: 0
|
||||
max_files_per_folder: 0
|
||||
max_services_per_node: 0
|
||||
max_nics_per_node: 1
|
||||
max_acl_rules: 0
|
||||
ip_list:
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.40
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: false
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: client
|
||||
type: computer
|
||||
ip_address: 192.168.1.2
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.3
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
server_ip: 192.168.1.3
|
||||
payload: "DELETE"
|
||||
|
||||
- hostname: switch_1
|
||||
type: switch
|
||||
num_ports: 2
|
||||
|
||||
- hostname: server
|
||||
type: server
|
||||
ip_address: 192.168.1.3
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- type: DatabaseService
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: client
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
@@ -0,0 +1,14 @@
|
||||
base_scenario: scenario.yaml
|
||||
schedule:
|
||||
0:
|
||||
- greens_0.yaml
|
||||
- reds_0.yaml
|
||||
1:
|
||||
- greens_0.yaml
|
||||
- reds_1.yaml
|
||||
2:
|
||||
- greens_1.yaml
|
||||
- reds_1.yaml
|
||||
3:
|
||||
- greens_2.yaml
|
||||
- reds_2.yaml
|
||||
372
src/primaite/notebooks/Using-Episode-Schedules.ipynb
Normal file
372
src/primaite/notebooks/Using-Episode-Schedules.ipynb
Normal file
@@ -0,0 +1,372 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Using Episode Schedules\n",
|
||||
"\n",
|
||||
"PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n",
|
||||
"domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n",
|
||||
"\n",
|
||||
"When using a fixed scenario, a single yaml config file is used. However, to use episode schedules, PrimAITE uses a \n",
|
||||
"directory with several config files that work together."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Defining variations in the config file.\n",
|
||||
"\n",
|
||||
"### Base scenario\n",
|
||||
"The base scenario is essentially the same as a fixed YAML configuration, but it can contain placeholders that are \n",
|
||||
"populated with episode-specific data at runtime. The base scenario contains any network, agent, or settings that\n",
|
||||
"remain fixed for the entire training/evaluation session.\n",
|
||||
"\n",
|
||||
"The placeholders are defined as YAML Aliases and they are denoted by an asterisk (`*placeholder`).\n",
|
||||
"\n",
|
||||
"### Variations\n",
|
||||
"For each variation that could be used in a placeholder, there is a separate yaml file that contains the data that should populate the placeholder.\n",
|
||||
"\n",
|
||||
"The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand (`&anchor`).\n",
|
||||
"\n",
|
||||
"[Learn more about YAML Aliases and Anchors here.](https://www.educative.io/blog/advanced-yaml-syntax-cheatsheet#:~:text=YAML%20Anchors%20and%20Alias)\n",
|
||||
"\n",
|
||||
"### Schedule\n",
|
||||
"Users must define which combination of scenario variations should be loaded in each episode. This takes the form of a\n",
|
||||
"YAML file with a relative path to the base scenario and a list of paths to be loaded in during each episode.\n",
|
||||
"\n",
|
||||
"It takes the following format:\n",
|
||||
"```yaml\n",
|
||||
"base_scenario: base.yaml\n",
|
||||
"schedule:\n",
|
||||
" 0: # list of variations to load in at episode 0 (before the first call to env.reset() happens)\n",
|
||||
" - laydown_1.yaml\n",
|
||||
" - attack_1.yaml\n",
|
||||
" 1: # list of variations to load in at episode 1 (after the first env.reset() call)\n",
|
||||
" - laydown_2.yaml\n",
|
||||
" - attack_2.yaml\n",
|
||||
"```\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Demonstration"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run `primaite setup` to copy the example config files into the correct directory. Then, import and define config location."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!primaite setup"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import yaml\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n",
|
||||
"from primaite import PRIMAITE_PATHS\n",
|
||||
"from prettytable import PrettyTable\n",
|
||||
"scenario_path = PRIMAITE_PATHS.user_config_path / \"example_config/scenario_with_placeholders\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Base Scenario File\n",
|
||||
"Let's view the contents of the base scenario file:\n",
|
||||
"\n",
|
||||
"It contains all the base settings that stay fixed throughout all episodes, including the `io_settings`, `game` settings, the network layout and the blue agent definition. There are two placeholders: `*greens` and `*reds`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"scenario.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Schedule File\n",
|
||||
"Let's view the contents of the schedule file:\n",
|
||||
"\n",
|
||||
"This file references the base scenario file and defines which variations should be loaded in at each episode. In this instance, there are four episodes, during the first episode `greens_0` and `reds_0` is used, during the second episode `greens_0` and `reds_1` is used, and so on."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"schedule.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Green Agent Variation Files\n",
|
||||
"\n",
|
||||
"There are three different variants of the green agent setup. In `greens_0`, there are no green agents, in `greens_1` there is a green agent that executes the database client application 80% of the time, and in `greens_2` there is a green agent that executes the database client application 5% of the time.\n",
|
||||
"\n",
|
||||
"(the difference between `greens_1` and `greens_2` is in the agent name and action probabilities)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"greens_0.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"greens_1.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"greens_2.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Red Agent Variation Files\n",
|
||||
"\n",
|
||||
"There are three different variants of the red agent setup. In `reds_0`, there are no red agents, in `reds_1` there is a red agent that executes every 20 steps, but in `reds_2` there is a red agent that executes every 2 steps."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"reds_0.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"reds_1.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(scenario_path/\"reds_2.yaml\") as f:\n",
|
||||
" print(f.read())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Running the simulation"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Create the environment using the variable config."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = PrimaiteGymEnv(env_config=scenario_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Episode 0\n",
|
||||
"Let' run the episodes to verify that the agents are changing as expected. In episode 0, there should be no green or red agents, just the defender blue agent."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(f\"Current episode number: {env.episode_counter}\")\n",
|
||||
"print(f\"Agents present: {list(env.game.agents.keys())}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Episode 1\n",
|
||||
"When we reset the environment, it moves onto episode 1, where it will bring in reds_1 for red agent definition.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()\n",
|
||||
"print(f\"Current episode number: {env.episode_counter}\")\n",
|
||||
"print(f\"Agents present: {list(env.game.agents.keys())}\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Episode 2\n",
|
||||
"When we reset the environment again, it moves onto episode 2, where it will bring in greens_1 and reds_1 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
|
||||
"\n",
|
||||
"Most green actions will be `NODE_APPLICATION_EXECUTE` while red will `DONOTHING` except at steps 10 and 20."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()\n",
|
||||
"print(f\"Current episode number: {env.episode_counter}\")\n",
|
||||
"print(f\"Agents present: {list(env.game.agents.keys())}\")\n",
|
||||
"for i in range(21):\n",
|
||||
" env.step(0)\n",
|
||||
"\n",
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_A'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_A'].action_history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Episode 3\n",
|
||||
"When we reset the environment again, it moves onto episode 3, where it will bring in greens_2 and reds_2 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
|
||||
"\n",
|
||||
"Now, green will perform `NODE_APPLICATION_EXECUTE` only 5% of the time, while red will perform `NODE_APPLICATION_EXECUTE` more frequently than before."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset()\n",
|
||||
"print(f\"Current episode number: {env.episode_counter}\")\n",
|
||||
"print(f\"Agents present: {list(env.game.agents.keys())}\")\n",
|
||||
"for i in range(21):\n",
|
||||
" env.step(0)\n",
|
||||
"\n",
|
||||
"table = PrettyTable()\n",
|
||||
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
|
||||
"for i in range(21):\n",
|
||||
" green_action = env.game.agents['green_B'].action_history[i].action\n",
|
||||
" red_action = env.game.agents['red_B'].action_history[i].action\n",
|
||||
" table.add_row([i, green_action, red_action])\n",
|
||||
"print(table)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Further Episodes\n",
|
||||
"\n",
|
||||
"Since the schedule definition only goes up to episode 3, if we reset the environment again, we run out of episodes. The environment will simply loop back to the beginning, but it produces a warning message to make users aware that the episodes are being repeated."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env.reset(); # semicolon suppresses jupyter outputting the observation space.\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
@@ -9,6 +9,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
@@ -23,17 +24,14 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, game_config: Dict):
|
||||
def __init__(self, env_config: Union[Dict, str, PathLike]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
self.game_config: Dict = game_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
"""Current game."""
|
||||
self._agent_name = next(iter(self.game.rl_agents))
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
@@ -94,9 +92,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state=state)
|
||||
next_obs = self._get_obs()
|
||||
@@ -141,8 +139,8 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(game_config=env_config)
|
||||
self.env.episode_counter -= 1
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
@@ -158,6 +156,11 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
@property
|
||||
def game(self) -> PrimaiteGame:
|
||||
"""Pass through game from env."""
|
||||
return self.env.game
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
@@ -169,16 +172,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.game_config: Dict = env_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
@@ -204,9 +207,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
123
src/primaite/session/episode_schedule.py
Normal file
123
src/primaite/session/episode_schedule.py
Normal file
@@ -0,0 +1,123 @@
|
||||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Mapping, Sequence, Union
|
||||
|
||||
import pydantic
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class EpisodeScheduler(pydantic.BaseModel, ABC):
|
||||
"""
|
||||
Episode schedulers provide functionality to select different scenarios and game setups for each episode.
|
||||
|
||||
This is useful when implementing advanced RL concepts like curriculum learning and domain randomisation.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, episode_num: int) -> Dict:
|
||||
"""Return the config that should be used during this episode."""
|
||||
...
|
||||
|
||||
|
||||
class ConstantEpisodeScheduler(EpisodeScheduler):
|
||||
"""The constant episode schedule simply provides the same game setup every time."""
|
||||
|
||||
config: Dict
|
||||
|
||||
def __call__(self, episode_num: int) -> Dict:
|
||||
"""Return the same config every time."""
|
||||
return copy.deepcopy(self.config)
|
||||
|
||||
|
||||
class EpisodeListScheduler(EpisodeScheduler):
|
||||
"""Cycle through a list of different game setups for each episode."""
|
||||
|
||||
schedule: Mapping[int, List[str]]
|
||||
"""Mapping from episode number to list of filenames"""
|
||||
episode_data: Mapping[str, str]
|
||||
"""Mapping from filename to yaml string."""
|
||||
base_scenario: str
|
||||
"""yaml string containing the base scenario."""
|
||||
|
||||
_exceeded_episode_list: bool = False
|
||||
"""
|
||||
Flag that's set to true when attempting to keep generating episodes after schedule runs out.
|
||||
|
||||
When this happens, we loop back to the beginning, but a warning is raised.
|
||||
"""
|
||||
|
||||
def __call__(self, episode_num: int) -> Dict:
|
||||
"""Return the config for the given episode number."""
|
||||
if episode_num >= len(self.schedule):
|
||||
if not self._exceeded_episode_list:
|
||||
self._exceeded_episode_list = True
|
||||
_LOGGER.warn(
|
||||
f"Running episode {episode_num} but the schedule only defines "
|
||||
f"{len(self.schedule)} episodes. Looping back to the beginning"
|
||||
)
|
||||
# not sure if we should be using a traditional warning, or a _LOGGER.warning
|
||||
episode_num = episode_num % len(self.schedule)
|
||||
|
||||
filenames_to_join = self.schedule[episode_num]
|
||||
yaml_data_to_join = [self.episode_data[fn] for fn in filenames_to_join] + [self.base_scenario]
|
||||
joined_yaml = "\n".join(yaml_data_to_join)
|
||||
parsed_cfg = yaml.safe_load(joined_yaml)
|
||||
|
||||
# Unfortunately, using placeholders like this is slightly hacky, so we have to flatten the list of agents
|
||||
flat_agents_list = []
|
||||
for a in parsed_cfg["agents"]:
|
||||
if isinstance(a, Sequence):
|
||||
flat_agents_list.extend(a)
|
||||
else:
|
||||
flat_agents_list.append(a)
|
||||
parsed_cfg["agents"] = flat_agents_list
|
||||
|
||||
return parsed_cfg
|
||||
|
||||
|
||||
def build_scheduler(config: Union[str, Path, Dict]) -> EpisodeScheduler:
|
||||
"""
|
||||
Convenience method to build an EpisodeScheduler with a dict, file path, or folder path.
|
||||
|
||||
If a path to a folder is provided, it will be treated as a list of game scenarios.
|
||||
Otherwise, if a dict or a single file is provided, it will be treated as a constant game scenario.
|
||||
"""
|
||||
# If we get a dict, return a constant episode schedule that repeats that one config forever
|
||||
if isinstance(config, Dict):
|
||||
return ConstantEpisodeScheduler(config=config)
|
||||
|
||||
# Cast string to Path
|
||||
if isinstance(config, str):
|
||||
config = Path(config)
|
||||
|
||||
if not config.exists():
|
||||
raise FileNotFoundError(f"Provided config path {config} could not be found.")
|
||||
|
||||
if config.is_file():
|
||||
with open(config, "r") as f:
|
||||
cfg_data = yaml.safe_load(f)
|
||||
return ConstantEpisodeScheduler(config=cfg_data)
|
||||
|
||||
if not config.is_dir():
|
||||
raise RuntimeError("Something went wrong while building Primaite config.")
|
||||
|
||||
root = config
|
||||
schedule_path = root / "schedule.yaml"
|
||||
|
||||
with open(schedule_path, "r") as f:
|
||||
schedule = yaml.safe_load(f)
|
||||
|
||||
base_scenario_path = root / schedule["base_scenario"]
|
||||
files_to_load = set(chain.from_iterable(schedule["schedule"].values()))
|
||||
|
||||
episode_data = {fp: (root / fp).read_text() for fp in files_to_load}
|
||||
|
||||
return EpisodeListScheduler(
|
||||
schedule=schedule["schedule"], episode_data=episode_data, base_scenario=base_scenario_path.read_text()
|
||||
)
|
||||
@@ -6,7 +6,8 @@ from typing import Dict, List, Optional
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
from src.primaite.utils.primaite_config_utils import is_dev_mode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -35,6 +36,8 @@ class PrimaiteIO:
|
||||
"""Whether to save system logs."""
|
||||
write_sys_log_to_terminal: bool = False
|
||||
"""Whether to write the sys log to the terminal."""
|
||||
sys_log_level: LogLevel = LogLevel.INFO
|
||||
"""The level of log that should be included in the logfiles/logged into terminal."""
|
||||
|
||||
def __init__(self, settings: Optional[Settings] = None) -> None:
|
||||
"""
|
||||
@@ -50,6 +53,7 @@ class PrimaiteIO:
|
||||
SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs
|
||||
SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = self.settings.write_sys_log_to_terminal
|
||||
SIM_OUTPUT.sys_log_level = self.settings.sys_log_level
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
@@ -57,7 +61,14 @@ class PrimaiteIO:
|
||||
timestamp = datetime.now()
|
||||
date_str = timestamp.strftime("%Y-%m-%d")
|
||||
time_str = timestamp.strftime("%H-%M-%S")
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
|
||||
# check if running in dev mode
|
||||
if is_dev_mode():
|
||||
# if dev mode, simulation output will be the current working directory
|
||||
session_path = Path.cwd() / "simulation_output" / date_str / time_str
|
||||
else:
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
return session_path
|
||||
|
||||
@@ -96,6 +107,10 @@ class PrimaiteIO:
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
"""Create an instance of PrimaiteIO based on a configuration dict."""
|
||||
config = config or {}
|
||||
|
||||
if config.get("sys_log_level"):
|
||||
config["sys_log_level"] = LogLevel[config["sys_log_level"].upper()] # convert to enum
|
||||
|
||||
new = cls(settings=cls.Settings(**config))
|
||||
|
||||
return new
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# The main PrimAITE application config file
|
||||
|
||||
developer_mode: False # false by default
|
||||
|
||||
# Logging
|
||||
logging:
|
||||
log_level: INFO
|
||||
@@ -9,14 +11,3 @@ logging:
|
||||
WARNING: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
ERROR: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
CRITICAL: '%(asctime)s::%(levelname)s::%(name)s::%(lineno)s::%(message)s'
|
||||
|
||||
# Session
|
||||
session:
|
||||
outputs:
|
||||
plots:
|
||||
size:
|
||||
auto_size: false
|
||||
width: 1500
|
||||
height: 900
|
||||
template: plotly_white
|
||||
range_slider: false
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory."""
|
||||
from datetime import datetime
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import _PRIMAITE_ROOT
|
||||
@@ -7,6 +8,21 @@ from primaite import _PRIMAITE_ROOT
|
||||
__all__ = ["SIM_OUTPUT"]
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
"""Enum containing all the available log levels for PrimAITE simulation output."""
|
||||
|
||||
DEBUG = 10
|
||||
"""Debug items will be output to terminal or log file."""
|
||||
INFO = 20
|
||||
"""Info items will be output to terminal or log file."""
|
||||
WARNING = 30
|
||||
"""Warnings will be output to terminal or log file."""
|
||||
ERROR = 40
|
||||
"""Errors will be output to terminal or log file."""
|
||||
CRITICAL = 50
|
||||
"""Critical errors will be output to terminal or log file."""
|
||||
|
||||
|
||||
class _SimOutput:
|
||||
def __init__(self):
|
||||
self._path: Path = (
|
||||
@@ -15,6 +31,7 @@ class _SimOutput:
|
||||
self.save_pcap_logs: bool = False
|
||||
self.save_sys_logs: bool = False
|
||||
self.write_sys_log_to_terminal: bool = False
|
||||
self.sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
|
||||
@@ -6,7 +6,6 @@ from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file import File
|
||||
@@ -14,8 +13,6 @@ from primaite.simulator.file_system.file_type import FileType
|
||||
from primaite.simulator.file_system.folder import Folder
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class FileSystem(SimComponent):
|
||||
"""Class that contains all the simulation File System."""
|
||||
@@ -163,11 +160,11 @@ class FileSystem(SimComponent):
|
||||
:param folder_name: The name of the folder.
|
||||
"""
|
||||
if folder_name == "root":
|
||||
self.sys_log.warning("Cannot delete the root folder.")
|
||||
self.sys_log.error("Cannot delete the root folder.")
|
||||
return False
|
||||
folder = self.get_folder(folder_name)
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}")
|
||||
self.sys_log.error(f"Cannot delete folder as it does not exist: {folder_name}")
|
||||
return False
|
||||
|
||||
# set folder to deleted state
|
||||
@@ -180,7 +177,7 @@ class FileSystem(SimComponent):
|
||||
folder.remove_all_files()
|
||||
|
||||
self.deleted_folders[folder.uuid] = folder
|
||||
self.sys_log.info(f"Deleted folder /{folder.name} and its contents")
|
||||
self.sys_log.warning(f"Deleted folder /{folder.name} and its contents")
|
||||
return True
|
||||
|
||||
def delete_folder_by_id(self, folder_uuid: str) -> None:
|
||||
@@ -283,7 +280,7 @@ class FileSystem(SimComponent):
|
||||
folder = self.get_folder(folder_name, include_deleted=include_deleted)
|
||||
if folder:
|
||||
return folder.get_file(file_name, include_deleted=include_deleted)
|
||||
self.sys_log.info(f"File not found /{folder_name}/{file_name}")
|
||||
self.sys_log.warning(f"File not found /{folder_name}/{file_name}")
|
||||
|
||||
def get_file_by_id(
|
||||
self, file_uuid: str, folder_uuid: Optional[str] = None, include_deleted: Optional[bool] = False
|
||||
@@ -499,7 +496,7 @@ class FileSystem(SimComponent):
|
||||
"""
|
||||
folder = self.get_folder(folder_name=folder_name)
|
||||
if not folder:
|
||||
_LOGGER.debug(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.")
|
||||
self.sys_log.error(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.")
|
||||
return False
|
||||
|
||||
file = folder.get_file(file_name=file_name, include_deleted=True)
|
||||
|
||||
@@ -5,14 +5,11 @@ from typing import Dict, Optional
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.file_system.file import File
|
||||
from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class Folder(FileSystemItemABC):
|
||||
"""Simulation Folder."""
|
||||
@@ -255,7 +252,7 @@ class Folder(FileSystemItemABC):
|
||||
file.delete()
|
||||
self.sys_log.info(f"Removed file {file.name} (id: {file.uuid})")
|
||||
else:
|
||||
_LOGGER.debug(f"File with UUID {file.uuid} was not found.")
|
||||
self.sys_log.error(f"File with UUID {file.uuid} was not found.")
|
||||
|
||||
def remove_file_by_id(self, file_uuid: str):
|
||||
"""
|
||||
|
||||
@@ -161,7 +161,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
|
||||
return
|
||||
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
self._connected_node.sys_log.info(
|
||||
self._connected_node.sys_log.error(
|
||||
f"Interface {self} cannot be enabled as the connected Node is not powered on"
|
||||
)
|
||||
return
|
||||
|
||||
@@ -307,13 +307,13 @@ class WiredNetworkInterface(NetworkInterface, ABC):
|
||||
return False
|
||||
|
||||
if self._connected_node.operating_state != NodeOperatingState.ON:
|
||||
self._connected_node.sys_log.info(
|
||||
self._connected_node.sys_log.warning(
|
||||
f"Interface {self} cannot be enabled as the connected Node is not powered on"
|
||||
)
|
||||
return False
|
||||
|
||||
if not self._connected_link:
|
||||
self._connected_node.sys_log.info(f"Interface {self} cannot be enabled as there is no Link connected.")
|
||||
self._connected_node.sys_log.warning(f"Interface {self} cannot be enabled as there is no Link connected.")
|
||||
return False
|
||||
|
||||
self.enabled = True
|
||||
@@ -1201,7 +1201,7 @@ class Node(SimComponent):
|
||||
self._nic_request_manager.add_request(new_nic_num, RequestType(func=network_interface._request_manager))
|
||||
else:
|
||||
msg = f"Cannot connect NIC {network_interface} as it is already connected"
|
||||
self.sys_log.logger.error(msg)
|
||||
self.sys_log.logger.warning(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
def disconnect_nic(self, network_interface: Union[NetworkInterface, str]):
|
||||
@@ -1228,7 +1228,7 @@ class Node(SimComponent):
|
||||
self._nic_request_manager.remove_request(network_interface_num)
|
||||
else:
|
||||
msg = f"Cannot disconnect Network Interface {network_interface} as it is not connected"
|
||||
self.sys_log.logger.error(msg)
|
||||
self.sys_log.logger.warning(msg)
|
||||
raise NetworkError(msg)
|
||||
|
||||
def ping(self, target_ip_address: Union[IPv4Address, str], pings: int = 4) -> bool:
|
||||
@@ -1299,7 +1299,6 @@ class Node(SimComponent):
|
||||
self.services.pop(service.uuid)
|
||||
service.parent = None
|
||||
self.sys_log.info(f"Uninstalled service {service.name}")
|
||||
_LOGGER.info(f"Removed service {service.name} from node {self.hostname}")
|
||||
self._service_request_manager.remove_request(service.name)
|
||||
|
||||
def install_application(self, application: Application) -> None:
|
||||
@@ -1335,7 +1334,6 @@ class Node(SimComponent):
|
||||
self.applications.pop(application.uuid)
|
||||
application.parent = None
|
||||
self.sys_log.info(f"Uninstalled application {application.name}")
|
||||
_LOGGER.info(f"Removed application {application.name} from node {self.hostname}")
|
||||
self._application_request_manager.remove_request(application.name)
|
||||
|
||||
def application_install_action(self, application: Application, ip_address: Optional[str] = None) -> bool:
|
||||
@@ -1360,7 +1358,6 @@ class Node(SimComponent):
|
||||
self.software_manager.install(application)
|
||||
application_instance = self.software_manager.software.get(str(application.__name__))
|
||||
self.applications[application_instance.uuid] = application_instance
|
||||
self.sys_log.info(f"Installed application {application_instance.name}")
|
||||
_LOGGER.debug(f"Added application {application_instance.name} to node {self.hostname}")
|
||||
self._application_request_manager.add_request(
|
||||
application_instance.name, RequestType(func=application_instance._request_manager)
|
||||
|
||||
@@ -147,7 +147,7 @@ class HostARP(ARP):
|
||||
super()._process_arp_request(arp_packet, from_network_interface)
|
||||
# Unmatched ARP Request
|
||||
if arp_packet.target_ip_address != from_network_interface.ip_address:
|
||||
self.sys_log.info(
|
||||
self.sys_log.warning(
|
||||
f"Ignoring ARP request for {arp_packet.target_ip_address}. Current IP address is "
|
||||
f"{from_network_interface.ip_address}"
|
||||
)
|
||||
|
||||
@@ -933,7 +933,7 @@ class RouterICMP(ICMP):
|
||||
)
|
||||
|
||||
if not network_interface:
|
||||
self.sys_log.error(
|
||||
self.sys_log.warning(
|
||||
"Cannot send ICMP echo reply as there is no outbound Network Interface to use. Try configuring the "
|
||||
"default gateway."
|
||||
)
|
||||
@@ -1482,7 +1482,7 @@ class Router(NetworkNode):
|
||||
frame.ethernet.dst_mac_addr = target_mac
|
||||
network_interface.send_frame(frame)
|
||||
else:
|
||||
self.sys_log.error(f"Frame dropped as there is no route to {frame.ip.dst_ip_address}")
|
||||
self.sys_log.warning(f"Frame dropped as there is no route to {frame.ip.dst_ip_address}")
|
||||
|
||||
def configure_port(self, port: int, ip_address: Union[IPv4Address, str], subnet_mask: Union[IPv4Address, str]):
|
||||
"""
|
||||
|
||||
@@ -74,7 +74,7 @@ class SwitchPort(WiredNetworkInterface):
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
self._connected_node.sys_log.warning("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
self.pcap.capture_inbound(frame)
|
||||
self._connected_node.receive_frame(frame=frame, from_network_interface=self)
|
||||
|
||||
@@ -68,7 +68,7 @@ class WirelessAccessPoint(IPWirelessNetworkInterface):
|
||||
if self.enabled:
|
||||
frame.decrement_ttl()
|
||||
if frame.ip and frame.ip.ttl < 1:
|
||||
self._connected_node.sys_log.info("Frame discarded as TTL limit reached")
|
||||
self._connected_node.sys_log.warning("Frame discarded as TTL limit reached")
|
||||
return False
|
||||
frame.set_received_timestamp()
|
||||
self.pcap.capture_inbound(frame)
|
||||
|
||||
@@ -2,13 +2,10 @@ from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Set
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class ApplicationOperatingState(Enum):
|
||||
"""Enumeration of Application Operating States."""
|
||||
@@ -99,7 +96,7 @@ class Application(IOSoftware):
|
||||
|
||||
if self.operating_state is not self.operating_state.RUNNING:
|
||||
# service is not running
|
||||
_LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
self.sys_log.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -131,7 +128,6 @@ class Application(IOSoftware):
|
||||
"""Install Application."""
|
||||
super().install()
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.sys_log.info(f"Installing Application {self.name}")
|
||||
self.operating_state = ApplicationOperatingState.INSTALLING
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
|
||||
@@ -1,16 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
from pydantic import BaseModel
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.software_manager import SoftwareManager
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
class DatabaseClientConnection(BaseModel):
|
||||
"""
|
||||
DatabaseClientConnection Class.
|
||||
|
||||
This class is used to record current DatabaseConnections within the DatabaseClient class.
|
||||
"""
|
||||
|
||||
connection_id: str
|
||||
"""Connection UUID."""
|
||||
|
||||
parent_node: HostNode
|
||||
"""The parent Node that this connection was created on."""
|
||||
|
||||
is_active: bool = True
|
||||
"""Flag to state whether the connection is still active or not."""
|
||||
|
||||
@property
|
||||
def client(self) -> Optional[DatabaseClient]:
|
||||
"""The DatabaseClient that holds this connection."""
|
||||
return self.parent_node.software_manager.software.get("DatabaseClient")
|
||||
|
||||
def query(self, sql: str) -> bool:
|
||||
"""
|
||||
Query the databaseserver.
|
||||
|
||||
:return: Boolean value
|
||||
"""
|
||||
if self.is_active and self.client:
|
||||
return self.client._query(connection_id=self.connection_id, sql=sql) # noqa
|
||||
return False
|
||||
|
||||
def disconnect(self):
|
||||
"""Disconnect the connection."""
|
||||
if self.client and self.is_active:
|
||||
self.client._disconnect(self.connection_id) # noqa
|
||||
|
||||
|
||||
class DatabaseClient(Application):
|
||||
@@ -25,13 +65,21 @@ class DatabaseClient(Application):
|
||||
|
||||
server_ip_address: Optional[IPv4Address] = None
|
||||
server_password: Optional[str] = None
|
||||
connected: bool = False
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
_last_connection_successful: Optional[bool] = None
|
||||
_query_success_tracker: Dict[str, bool] = {}
|
||||
"""Keep track of connections that were established or verified during this step. Used for rewards."""
|
||||
last_query_response: Optional[Dict] = None
|
||||
"""Keep track of the latest query response. Used to determine rewards."""
|
||||
_server_connection_id: Optional[str] = None
|
||||
"""Connection ID to the Database Server."""
|
||||
client_connections: Dict[str, DatabaseClientConnection] = {}
|
||||
"""Keep track of active connections to Database Server."""
|
||||
_client_connection_requests: Dict[str, Optional[str]] = {}
|
||||
"""Dictionary of connection requests to Database Server."""
|
||||
connected: bool = False
|
||||
"""Boolean Value for whether connected to DB Server."""
|
||||
native_connection: Optional[DatabaseClientConnection] = None
|
||||
"""Native Client Connection for using the client directly (similar to psql in a terminal)."""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
kwargs["name"] = "DatabaseClient"
|
||||
@@ -51,12 +99,18 @@ class DatabaseClient(Application):
|
||||
|
||||
def execute(self) -> bool:
|
||||
"""Execution definition for db client: perform a select query."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
self.num_executions += 1 # trying to connect counts as an execution
|
||||
if not self._server_connection_id:
|
||||
|
||||
if not self.native_connection:
|
||||
self.connect()
|
||||
can_connect = self.check_connection(connection_id=self._server_connection_id)
|
||||
self._last_connection_successful = can_connect
|
||||
return can_connect
|
||||
|
||||
if self.native_connection:
|
||||
return self.check_connection(connection_id=self.native_connection.connection_id)
|
||||
|
||||
return False
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -69,6 +123,23 @@ class DatabaseClient(Application):
|
||||
state["last_connection_successful"] = self._last_connection_successful
|
||||
return state
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
"""
|
||||
Display the client connections in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
table = PrettyTable(["Connection ID", "Active"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} {self.name} Client Connections"
|
||||
if self.native_connection:
|
||||
table.add_row([self.native_connection.connection_id, self.native_connection.is_active])
|
||||
for connection_id, connection in self.client_connections.items():
|
||||
table.add_row([connection_id, connection.is_active])
|
||||
print(table.get_string(sortby="Connection ID"))
|
||||
|
||||
def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None):
|
||||
"""
|
||||
Configure the DatabaseClient to communicate with a DatabaseService.
|
||||
@@ -81,21 +152,17 @@ class DatabaseClient(Application):
|
||||
self.sys_log.info(f"{self.name}: Configured the {self.name} with {server_ip_address=}, {server_password=}.")
|
||||
|
||||
def connect(self) -> bool:
|
||||
"""Connect to a Database Service."""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
"""Connect the native client connection."""
|
||||
if self.native_connection:
|
||||
return True
|
||||
self.native_connection = self.get_new_connection()
|
||||
return self.native_connection is not None
|
||||
|
||||
if not self._server_connection_id:
|
||||
self._server_connection_id = str(uuid4())
|
||||
|
||||
self.connected = self._connect(
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
connection_id=self._server_connection_id,
|
||||
)
|
||||
if not self.connected:
|
||||
self._server_connection_id = None
|
||||
return self.connected
|
||||
def disconnect(self):
|
||||
"""Disconnect the native client connection."""
|
||||
if self.native_connection:
|
||||
self._disconnect(self.native_connection.connection_id)
|
||||
self.native_connection = None
|
||||
|
||||
def check_connection(self, connection_id: str) -> bool:
|
||||
"""Check whether the connection can be successfully re-established.
|
||||
@@ -107,15 +174,19 @@ class DatabaseClient(Application):
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
||||
return self._query("SELECT * FROM pg_stat_activity", connection_id=connection_id)
|
||||
|
||||
def _check_client_connection(self, connection_id: str) -> bool:
|
||||
"""Check that client_connection_id is valid."""
|
||||
return True if connection_id in self._client_connection_requests else False
|
||||
|
||||
def _connect(
|
||||
self,
|
||||
server_ip_address: IPv4Address,
|
||||
connection_id: Optional[str] = None,
|
||||
connection_request_id: str,
|
||||
password: Optional[str] = None,
|
||||
is_reattempt: bool = False,
|
||||
) -> bool:
|
||||
) -> Optional[DatabaseClientConnection]:
|
||||
"""
|
||||
Connects the DatabaseClient to the DatabaseServer.
|
||||
|
||||
@@ -129,56 +200,106 @@ class DatabaseClient(Application):
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if is_reattempt:
|
||||
if self._server_connection_id:
|
||||
valid_connection = self._check_client_connection(connection_id=connection_request_id)
|
||||
if valid_connection:
|
||||
database_client_connection = self._client_connection_requests.pop(connection_request_id)
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} authorised"
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} authorised."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
)
|
||||
self.server_ip_address = server_ip_address
|
||||
return True
|
||||
self.connected = True
|
||||
self._last_connection_successful = True
|
||||
return database_client_connection
|
||||
else:
|
||||
self.sys_log.info(
|
||||
f"{self.name} {connection_id=}: DatabaseClient connection to {server_ip_address} declined"
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: DatabaseClient connection to {server_ip_address} declined."
|
||||
f"Connection Request ID was {connection_request_id}."
|
||||
)
|
||||
return False
|
||||
payload = {
|
||||
"type": "connect_request",
|
||||
"password": password,
|
||||
"connection_id": connection_id,
|
||||
}
|
||||
self._last_connection_successful = False
|
||||
return None
|
||||
payload = {"type": "connect_request", "password": password, "connection_request_id": connection_request_id}
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload=payload, dest_ip_address=server_ip_address, dest_port=self.port
|
||||
)
|
||||
return self._connect(
|
||||
server_ip_address=server_ip_address, password=password, connection_id=connection_id, is_reattempt=True
|
||||
server_ip_address=server_ip_address,
|
||||
password=password,
|
||||
is_reattempt=True,
|
||||
connection_request_id=connection_request_id,
|
||||
)
|
||||
|
||||
def disconnect(self) -> bool:
|
||||
"""Disconnect from the Database Service."""
|
||||
def _disconnect(self, connection_id: str) -> bool:
|
||||
"""Disconnect from the Database Service.
|
||||
|
||||
If no connection_id is provided, connect from first ID in
|
||||
self.client_connections.
|
||||
|
||||
:param: connection_id: connection ID to disconnect.
|
||||
:type: connection_id: str
|
||||
|
||||
:return: bool
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
# if there are no connections - nothing to disconnect
|
||||
if not self._server_connection_id:
|
||||
self.sys_log.error(f"Unable to disconnect - {self.name} has no active connections.")
|
||||
if len(self.client_connections) == 0:
|
||||
self.sys_log.warning(f"{self.name}: Unable to disconnect, no active connections.")
|
||||
return False
|
||||
if not self.client_connections.get(connection_id):
|
||||
return False
|
||||
|
||||
# if no connection provided, disconnect the first connection
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": self._server_connection_id},
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
dest_ip_address=self.server_ip_address,
|
||||
dest_port=self.port,
|
||||
)
|
||||
self.remove_connection(connection_id=self._server_connection_id)
|
||||
connection = self.client_connections.pop(connection_id)
|
||||
self.terminate_connection(connection_id=connection_id)
|
||||
|
||||
self.sys_log.info(
|
||||
f"{self.name}: DatabaseClient disconnected {self._server_connection_id} from {self.server_ip_address}"
|
||||
)
|
||||
connection.is_active = False
|
||||
|
||||
self.sys_log.info(f"{self.name}: DatabaseClient disconnected {connection_id} from {self.server_ip_address}")
|
||||
self.connected = False
|
||||
return True
|
||||
|
||||
def _query(self, sql: str, query_id: str, connection_id: str, is_reattempt: bool = False) -> bool:
|
||||
def uninstall(self) -> None:
|
||||
"""
|
||||
Uninstall the DatabaseClient.
|
||||
|
||||
Calls disconnect on all client connections to ensure that both client and server connections are killed.
|
||||
"""
|
||||
while self.client_connections.values():
|
||||
client_connection = self.client_connections[next(iter(self.client_connections.keys()))]
|
||||
client_connection.disconnect()
|
||||
super().uninstall()
|
||||
|
||||
def get_new_connection(self) -> Optional[DatabaseClientConnection]:
|
||||
"""Get a new connection to the DatabaseServer.
|
||||
|
||||
:return: DatabaseClientConnection object
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return None
|
||||
connection_request_id = str(uuid4())
|
||||
self._client_connection_requests[connection_request_id] = None
|
||||
|
||||
return self._connect(
|
||||
server_ip_address=self.server_ip_address,
|
||||
password=self.server_password,
|
||||
connection_request_id=connection_request_id,
|
||||
)
|
||||
|
||||
def _create_client_connection(self, connection_id: str, connection_request_id: str) -> None:
|
||||
"""Create a new DatabaseClientConnection Object."""
|
||||
client_connection = DatabaseClientConnection(
|
||||
connection_id=connection_id, client=self, parent_node=self.software_manager.node
|
||||
)
|
||||
self.client_connections[connection_id] = client_connection
|
||||
self._client_connection_requests[connection_request_id] = client_connection
|
||||
|
||||
def _query(self, sql: str, connection_id: str, query_id: Optional[str] = False, is_reattempt: bool = False) -> bool:
|
||||
"""
|
||||
Send a query to the connected database server.
|
||||
|
||||
@@ -188,15 +309,22 @@ class DatabaseClient(Application):
|
||||
:param: query_id: ID of the query, used as reference
|
||||
:type: query_id: str
|
||||
|
||||
:param: connection_id: ID of the connection to the database server.
|
||||
:type: connection_id: str
|
||||
|
||||
:param: is_reattempt: True if the query request has been reattempted. Default False
|
||||
:type: is_reattempt: Optional[bool]
|
||||
"""
|
||||
if not query_id:
|
||||
query_id = str(uuid4())
|
||||
if is_reattempt:
|
||||
success = self._query_success_tracker.get(query_id)
|
||||
if success:
|
||||
self.sys_log.info(f"{self.name}: Query successful {sql}")
|
||||
self._last_connection_successful = True
|
||||
return True
|
||||
self.sys_log.info(f"{self.name}: Unable to run query {sql}")
|
||||
self.sys_log.error(f"{self.name}: Unable to run query {sql}")
|
||||
self._last_connection_successful = False
|
||||
return False
|
||||
else:
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
@@ -211,39 +339,29 @@ class DatabaseClient(Application):
|
||||
"""Run the DatabaseClient."""
|
||||
super().run()
|
||||
|
||||
def query(self, sql: str, connection_id: Optional[str] = None) -> bool:
|
||||
def query(self, sql: str) -> bool:
|
||||
"""
|
||||
Send a query to the Database Service.
|
||||
|
||||
:param: sql: The SQL query.
|
||||
:param: is_reattempt: If true, the action has been reattempted.
|
||||
:type: sql: str
|
||||
|
||||
:return: True if the query was successful, otherwise False.
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if not self.native_connection:
|
||||
return False
|
||||
|
||||
# reset last query response
|
||||
self.last_query_response = None
|
||||
|
||||
connection_id: str
|
||||
|
||||
if not connection_id:
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
self.connect()
|
||||
connection_id = self._server_connection_id
|
||||
|
||||
if not connection_id:
|
||||
msg = "Cannot run sql query, could not establish connection with the server."
|
||||
self.parent.sys_log.error(msg)
|
||||
return False
|
||||
|
||||
uuid = str(uuid4())
|
||||
self._query_success_tracker[uuid] = False
|
||||
return self._query(sql=sql, query_id=uuid, connection_id=connection_id)
|
||||
return self.native_connection.query(sql)
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
def receive(self, session_id: str, payload: Any, **kwargs) -> bool:
|
||||
"""
|
||||
Receive a payload from the Software Manager.
|
||||
|
||||
@@ -253,17 +371,23 @@ class DatabaseClient(Application):
|
||||
"""
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_response":
|
||||
if payload["response"] is True:
|
||||
# add connection
|
||||
self.add_connection(connection_id=payload.get("connection_id"), session_id=session_id)
|
||||
connection_id = payload["connection_id"]
|
||||
self._create_client_connection(
|
||||
connection_id=connection_id, connection_request_id=payload["connection_request_id"]
|
||||
)
|
||||
elif payload["type"] == "sql":
|
||||
self.last_query_response = payload
|
||||
query_id = payload.get("uuid")
|
||||
status_code = payload.get("status_code")
|
||||
self._query_success_tracker[query_id] = status_code == 200
|
||||
if self._query_success_tracker[query_id]:
|
||||
_LOGGER.debug(f"Received payload {payload}")
|
||||
self.sys_log.debug(f"Received {payload=}")
|
||||
elif payload["type"] == "disconnect":
|
||||
connection_id = payload["connection_id"]
|
||||
self.sys_log.info(f"{self.name}: Received disconnect command for {connection_id=} from the server")
|
||||
self._disconnect(payload["connection_id"])
|
||||
return True
|
||||
|
||||
@@ -9,7 +9,7 @@ from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -53,6 +53,7 @@ class DataManipulationBot(Application):
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -71,7 +72,7 @@ class DataManipulationBot(Application):
|
||||
"""Return the database client that is installed on the same machine as the DataManipulationBot."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
if db_client is None:
|
||||
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
return db_client
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
@@ -127,7 +128,7 @@ class DataManipulationBot(Application):
|
||||
"""
|
||||
if self.attack_stage == DataManipulationAttackStage.NOT_STARTED:
|
||||
# Bypass this stage as we're not dealing with logon for now
|
||||
self.sys_log.info(f"{self.name}: ")
|
||||
self.sys_log.debug(f"{self.name}: ")
|
||||
self.attack_stage = DataManipulationAttackStage.LOGON
|
||||
|
||||
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
|
||||
@@ -145,9 +146,14 @@ class DataManipulationBot(Application):
|
||||
# perform the port scan
|
||||
port_is_open = True # Temporary; later we can implement NMAP port scan.
|
||||
if port_is_open:
|
||||
self.sys_log.info(f"{self.name}: ")
|
||||
self.sys_log.debug(f"{self.name}: ")
|
||||
self.attack_stage = DataManipulationAttackStage.PORT_SCAN
|
||||
|
||||
def _establish_db_connection(self) -> bool:
|
||||
"""Establish a db connection to the Database Server."""
|
||||
self._db_connection = self._host_db_client.get_new_connection()
|
||||
return True if self._db_connection else False
|
||||
|
||||
def _perform_data_manipulation(self, p_of_success: Optional[float] = 0.1):
|
||||
"""
|
||||
Execute the data manipulation attack on the target.
|
||||
@@ -167,17 +173,16 @@ class DataManipulationBot(Application):
|
||||
if simulate_trial(p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Performing data manipulation")
|
||||
# perform the attack
|
||||
if not len(self._host_db_client.connections):
|
||||
self._host_db_client.connect()
|
||||
if len(self._host_db_client.connections):
|
||||
self._host_db_client.query(self.payload)
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Data manipulation successful")
|
||||
self.attack_stage = DataManipulationAttackStage.SUCCEEDED
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Data manipulation failed")
|
||||
self.sys_log.warning(f"{self.name}: Data manipulation failed")
|
||||
self.attack_stage = DataManipulationAttackStage.FAILED
|
||||
|
||||
def run(self):
|
||||
@@ -191,7 +196,9 @@ class DataManipulationBot(Application):
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.")
|
||||
self.sys_log.warning(
|
||||
"Data manipulation application attempted to execute but it cannot perform actions right now."
|
||||
)
|
||||
self.run()
|
||||
|
||||
self.num_executions += 1
|
||||
@@ -206,7 +213,7 @@ class DataManipulationBot(Application):
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
if self.server_ip_address and self.payload:
|
||||
self.sys_log.info(f"{self.name}: Running")
|
||||
self.sys_log.debug(f"{self.name}: Running")
|
||||
self._logon()
|
||||
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
|
||||
self._perform_data_manipulation(p_of_success=self.data_manipulation_p_of_success)
|
||||
@@ -220,7 +227,7 @@ class DataManipulationBot(Application):
|
||||
return True
|
||||
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
self.sys_log.warning(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
|
||||
@@ -122,7 +122,7 @@ class DoSBot(DatabaseClient):
|
||||
|
||||
# DoS bot cannot do anything without a target
|
||||
if not self.target_ip_address or not self.target_port:
|
||||
self.sys_log.error(
|
||||
self.sys_log.warning(
|
||||
f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}"
|
||||
)
|
||||
return True
|
||||
@@ -152,7 +152,7 @@ class DoSBot(DatabaseClient):
|
||||
# perform the port scan
|
||||
port_is_open = True # Temporary; later we can implement NMAP port scan.
|
||||
if port_is_open:
|
||||
self.sys_log.info(f"{self.name}: ")
|
||||
self.sys_log.debug(f"{self.name}: ")
|
||||
self.attack_stage = DoSAttackStage.PORT_SCAN
|
||||
|
||||
def _perform_dos(self):
|
||||
|
||||
@@ -2,16 +2,13 @@ from enum import IntEnum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.science import simulate_trial
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
class RansomwareAttackStage(IntEnum):
|
||||
@@ -76,6 +73,7 @@ class RansomwareScript(Application):
|
||||
kwargs["protocol"] = IPProtocol.NONE
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
"""
|
||||
@@ -94,7 +92,7 @@ class RansomwareScript(Application):
|
||||
"""Return the database client that is installed on the same machine as the Ransomware Script."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
if db_client is None:
|
||||
_LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
self.sys_log.warning(f"{self.__class__.__name__} cannot find a database client on its host.")
|
||||
return db_client
|
||||
|
||||
def _init_request_manager(self) -> RequestManager:
|
||||
@@ -158,7 +156,7 @@ class RansomwareScript(Application):
|
||||
self.attack_stage = RansomwareAttackStage.NOT_STARTED
|
||||
return True
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
self.sys_log.warning(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
|
||||
return False
|
||||
|
||||
def configure(
|
||||
@@ -254,11 +252,16 @@ class RansomwareScript(Application):
|
||||
def attack(self) -> bool:
|
||||
"""Perform the attack steps after opening the application."""
|
||||
if not self._can_perform_action():
|
||||
_LOGGER.debug("Ransomware application is unable to perform it's actions.")
|
||||
self.sys_log.warning("Ransomware application is unable to perform it's actions.")
|
||||
self.run()
|
||||
self.num_executions += 1
|
||||
return self._application_loop()
|
||||
|
||||
def _establish_db_connection(self) -> bool:
|
||||
"""Establish a db connection to the Database Server."""
|
||||
self._db_connection = self._host_db_client.get_new_connection()
|
||||
return True if self._db_connection else False
|
||||
|
||||
def _perform_ransomware_encrypt(self):
|
||||
"""
|
||||
Execute the Ransomware Encrypt payload on the target.
|
||||
@@ -276,12 +279,11 @@ class RansomwareScript(Application):
|
||||
if self.attack_stage == RansomwareAttackStage.COMMAND_AND_CONTROL:
|
||||
if simulate_trial(self.ransomware_encrypt_p_of_success):
|
||||
self.sys_log.info(f"{self.name}: Attempting to launch payload")
|
||||
if not len(self._host_db_client.connections):
|
||||
self._host_db_client.connect()
|
||||
if len(self._host_db_client.connections):
|
||||
self._host_db_client.query(self.payload)
|
||||
if not self._db_connection:
|
||||
self._establish_db_connection()
|
||||
if self._db_connection:
|
||||
attack_successful = self._db_connection.query(self.payload)
|
||||
self.sys_log.info(f"{self.name} Payload delivered: {self.payload}")
|
||||
attack_successful = True
|
||||
if attack_successful:
|
||||
self.sys_log.info(f"{self.name}: Payload Successful")
|
||||
self.attack_stage = RansomwareAttackStage.SUCCEEDED
|
||||
@@ -289,7 +291,7 @@ class RansomwareScript(Application):
|
||||
self.sys_log.info(f"{self.name}: Payload failed")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
else:
|
||||
self.sys_log.error("Attack Attempted to launch too quickly")
|
||||
self.sys_log.warning("Attack Attempted to launch too quickly")
|
||||
self.attack_stage = RansomwareAttackStage.FAILED
|
||||
|
||||
def _local_download(self):
|
||||
|
||||
@@ -97,7 +97,7 @@ class WebBrowser(Application):
|
||||
try:
|
||||
parsed_url = urlparse(url)
|
||||
except Exception:
|
||||
self.sys_log.error(f"{url} is not a valid URL")
|
||||
self.sys_log.warning(f"{url} is not a valid URL")
|
||||
return False
|
||||
|
||||
# get the IP address of the domain name via DNS
|
||||
@@ -114,7 +114,7 @@ class WebBrowser(Application):
|
||||
self.domain_name_ip_address = IPv4Address(parsed_url.hostname)
|
||||
except Exception:
|
||||
# unable to deal with this request
|
||||
self.sys_log.error(f"{self.name}: Unable to resolve URL {url}")
|
||||
self.sys_log.warning(f"{self.name}: Unable to resolve URL {url}")
|
||||
return False
|
||||
|
||||
# create HTTPRequest payload
|
||||
@@ -140,7 +140,8 @@ class WebBrowser(Application):
|
||||
)
|
||||
return self.latest_response.status_code is HttpStatusCode.OK
|
||||
else:
|
||||
self.sys_log.error(f"Error sending Http Packet {str(payload)}")
|
||||
self.sys_log.warning(f"{self.name}: Error sending Http Packet")
|
||||
self.sys_log.debug(f"{self.name}: {payload=}")
|
||||
self.history.append(
|
||||
WebBrowser.BrowserHistoryItem(
|
||||
url=url, status=self.BrowserHistoryItem._HistoryItemStatus.SERVER_UNREACHABLE
|
||||
@@ -181,7 +182,8 @@ class WebBrowser(Application):
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
if not isinstance(payload, HttpResponsePacket):
|
||||
self.sys_log.error(f"{self.name} received a packet that is not an HttpResponsePacket")
|
||||
self.sys_log.warning(f"{self.name} received a packet that is not an HttpResponsePacket")
|
||||
self.sys_log.debug(f"{self.name}: {payload=}")
|
||||
return False
|
||||
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
|
||||
self.latest_response = payload
|
||||
|
||||
@@ -87,7 +87,7 @@ class SoftwareManager:
|
||||
# TODO: Software manager and node itself both have an install method. Need to refactor to have more logical
|
||||
# separation of concerns.
|
||||
if software_class in self._software_class_to_name_map:
|
||||
self.sys_log.info(f"Cannot install {software_class} as it is already installed")
|
||||
self.sys_log.warning(f"Cannot install {software_class} as it is already installed")
|
||||
return
|
||||
software = software_class(
|
||||
software_manager=self, sys_log=self.sys_log, file_system=self.file_system, dns_server=self.dns_server
|
||||
@@ -97,7 +97,6 @@ class SoftwareManager:
|
||||
software.software_manager = self
|
||||
self.software[software.name] = software
|
||||
self.port_protocol_mapping[(software.port, software.protocol)] = software
|
||||
self.sys_log.info(f"Installed {software.name}")
|
||||
if isinstance(software, Application):
|
||||
software.operating_state = ApplicationOperatingState.CLOSED
|
||||
|
||||
@@ -114,6 +113,7 @@ class SoftwareManager:
|
||||
:param software_name: The software name.
|
||||
"""
|
||||
if software_name in self.software:
|
||||
self.software[software_name].uninstall()
|
||||
software = self.software.pop(software_name) # noqa
|
||||
if isinstance(software, Application):
|
||||
self.node.uninstall_application(software)
|
||||
@@ -144,7 +144,7 @@ class SoftwareManager:
|
||||
if receiver:
|
||||
receiver.receive_payload(payload)
|
||||
else:
|
||||
self.sys_log.error(f"No Service of Application found with the name {target_software}")
|
||||
self.sys_log.warning(f"No Service of Application found with the name {target_software}")
|
||||
|
||||
def send_payload_to_session_manager(
|
||||
self,
|
||||
@@ -196,7 +196,7 @@ class SoftwareManager:
|
||||
payload=payload, session_id=session_id, from_network_interface=from_network_interface, frame=frame
|
||||
)
|
||||
else:
|
||||
self.sys_log.error(f"No service or application found for port {port} and protocol {protocol}")
|
||||
self.sys_log.warning(f"No service or application found for port {port} and protocol {protocol}")
|
||||
pass
|
||||
|
||||
def show(self, markdown: bool = False):
|
||||
|
||||
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
|
||||
|
||||
class _NotJSONFilter(logging.Filter):
|
||||
@@ -52,6 +52,7 @@ class SysLog:
|
||||
file_handler.setFormatter(logging.Formatter(log_format))
|
||||
|
||||
self.logger = logging.getLogger(f"{self.hostname}_sys_log")
|
||||
self.logger.handlers.clear() # clear handlers
|
||||
self.logger.setLevel(logging.DEBUG)
|
||||
self.logger.addHandler(file_handler)
|
||||
|
||||
@@ -99,6 +100,9 @@ class SysLog:
|
||||
:param msg: The message to be logged.
|
||||
:param to_terminal: If True, prints to the terminal too.
|
||||
"""
|
||||
if SIM_OUTPUT.sys_log_level > LogLevel.DEBUG:
|
||||
return
|
||||
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.debug(msg)
|
||||
self._write_to_terminal(msg, "DEBUG", to_terminal)
|
||||
@@ -110,6 +114,9 @@ class SysLog:
|
||||
:param msg: The message to be logged.
|
||||
:param to_terminal: If True, prints to the terminal too.
|
||||
"""
|
||||
if SIM_OUTPUT.sys_log_level > LogLevel.INFO:
|
||||
return
|
||||
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.info(msg)
|
||||
self._write_to_terminal(msg, "INFO", to_terminal)
|
||||
@@ -121,6 +128,9 @@ class SysLog:
|
||||
:param msg: The message to be logged.
|
||||
:param to_terminal: If True, prints to the terminal too.
|
||||
"""
|
||||
if SIM_OUTPUT.sys_log_level > LogLevel.WARNING:
|
||||
return
|
||||
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.warning(msg)
|
||||
self._write_to_terminal(msg, "WARNING", to_terminal)
|
||||
@@ -132,6 +142,9 @@ class SysLog:
|
||||
:param msg: The message to be logged.
|
||||
:param to_terminal: If True, prints to the terminal too.
|
||||
"""
|
||||
if SIM_OUTPUT.sys_log_level > LogLevel.ERROR:
|
||||
return
|
||||
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.error(msg)
|
||||
self._write_to_terminal(msg, "ERROR", to_terminal)
|
||||
@@ -143,6 +156,9 @@ class SysLog:
|
||||
:param msg: The message to be logged.
|
||||
:param to_terminal: If True, prints to the terminal too.
|
||||
"""
|
||||
if LogLevel.CRITICAL < SIM_OUTPUT.sys_log_level:
|
||||
return
|
||||
|
||||
if SIM_OUTPUT.save_sys_logs:
|
||||
self.logger.critical(msg)
|
||||
self._write_to_terminal(msg, "CRITICAL", to_terminal)
|
||||
|
||||
@@ -147,7 +147,7 @@ class ARP(Service):
|
||||
payload=arp_packet, dst_ip_address=target_ip_address, dst_port=self.port, ip_protocol=self.protocol
|
||||
)
|
||||
else:
|
||||
self.sys_log.error(
|
||||
self.sys_log.warning(
|
||||
"Cannot send ARP request as there is no outbound Network Interface to use. Try configuring the default "
|
||||
"gateway."
|
||||
)
|
||||
@@ -173,7 +173,7 @@ class ARP(Service):
|
||||
ip_protocol=self.protocol,
|
||||
)
|
||||
else:
|
||||
self.sys_log.error(
|
||||
self.sys_log.warning(
|
||||
"Cannot send ARP reply as there is no outbound Network Interface to use. Try configuring the default "
|
||||
"gateway."
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Literal, Optional, Union
|
||||
from uuid import uuid4
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.simulator.file_system.file_system import File
|
||||
@@ -57,7 +58,7 @@ class DatabaseService(Service):
|
||||
|
||||
# check if the backup server was configured
|
||||
if self.backup_server_ip is None:
|
||||
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
self.sys_log.warning(f"{self.name} - {self.sys_log.hostname}: not configured.")
|
||||
return False
|
||||
|
||||
software_manager: SoftwareManager = self.software_manager
|
||||
@@ -110,7 +111,7 @@ class DatabaseService(Service):
|
||||
db_file = self.file_system.get_file(folder_name="database", file_name="database.db", include_deleted=True)
|
||||
|
||||
if db_file is None:
|
||||
self.sys_log.error("Database file not initialised.")
|
||||
self.sys_log.warning("Database file not initialised.")
|
||||
return False
|
||||
|
||||
# if the file was deleted, get the old visible health state
|
||||
@@ -145,8 +146,16 @@ class DatabaseService(Service):
|
||||
"""Returns the database folder."""
|
||||
return self.file_system.get_folder_by_id(self.db_file.folder_id)
|
||||
|
||||
def _generate_connection_id(self) -> str:
|
||||
"""Generate a unique connection ID."""
|
||||
return str(uuid4())
|
||||
|
||||
def _process_connect(
|
||||
self, connection_id: str, password: Optional[str] = None
|
||||
self,
|
||||
src_ip: IPv4Address,
|
||||
connection_request_id: str,
|
||||
password: Optional[str] = None,
|
||||
session_id: Optional[str] = None,
|
||||
) -> Dict[str, Union[int, Dict[str, bool]]]:
|
||||
"""Process an incoming connection request.
|
||||
|
||||
@@ -158,24 +167,24 @@ class DatabaseService(Service):
|
||||
:rtype: Dict[str, Union[int, Dict[str, bool]]]
|
||||
"""
|
||||
status_code = 500 # Default internal server error
|
||||
connection_id = None
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
status_code = 503 # service unavailable
|
||||
if self.health_state_actual == SoftwareHealthState.OVERWHELMED:
|
||||
self.sys_log.error(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity."
|
||||
)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {src_ip=} declined. Service is at capacity.")
|
||||
if self.health_state_actual == SoftwareHealthState.GOOD:
|
||||
if self.password == password:
|
||||
status_code = 200 # ok
|
||||
connection_id = self._generate_connection_id()
|
||||
# try to create connection
|
||||
if not self.add_connection(connection_id=connection_id):
|
||||
if not self.add_connection(connection_id=connection_id, session_id=session_id):
|
||||
status_code = 500
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
else:
|
||||
status_code = 401 # Unauthorised
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined")
|
||||
else:
|
||||
status_code = 404 # service not found
|
||||
return {
|
||||
@@ -183,6 +192,7 @@ class DatabaseService(Service):
|
||||
"type": "connect_response",
|
||||
"response": status_code == 200,
|
||||
"connection_id": connection_id,
|
||||
"connection_request_id": connection_request_id,
|
||||
}
|
||||
|
||||
def _process_sql(
|
||||
@@ -206,7 +216,7 @@ class DatabaseService(Service):
|
||||
self.sys_log.info(f"{self.name}: Running {query}")
|
||||
|
||||
if not self.db_file:
|
||||
self.sys_log.info(f"{self.name}: Failed to run {query} because the database file is missing.")
|
||||
self.sys_log.error(f"{self.name}: Failed to run {query} because the database file is missing.")
|
||||
return {"status_code": 404, "type": "sql", "data": False}
|
||||
|
||||
if query == "SELECT":
|
||||
@@ -276,7 +286,7 @@ class DatabaseService(Service):
|
||||
return {"status_code": 401, "data": False}
|
||||
else:
|
||||
# Invalid query
|
||||
self.sys_log.info(f"{self.name}: Invalid {query}")
|
||||
self.sys_log.warning(f"{self.name}: Invalid {query}")
|
||||
return {"status_code": 500, "data": False}
|
||||
|
||||
def describe_state(self) -> Dict:
|
||||
@@ -299,19 +309,34 @@ class DatabaseService(Service):
|
||||
:return: True if the Status Code is 200, otherwise False.
|
||||
"""
|
||||
result = {"status_code": 500, "data": []}
|
||||
|
||||
# if server service is down, return error
|
||||
if not self._can_perform_action():
|
||||
return False
|
||||
|
||||
if isinstance(payload, dict) and payload.get("type"):
|
||||
if payload["type"] == "connect_request":
|
||||
src_ip = kwargs.get("frame").ip.src_ip_address
|
||||
result = self._process_connect(
|
||||
connection_id=payload.get("connection_id"), password=payload.get("password")
|
||||
src_ip=src_ip,
|
||||
password=payload.get("password"),
|
||||
connection_request_id=payload.get("connection_request_id"),
|
||||
session_id=session_id,
|
||||
)
|
||||
elif payload["type"] == "disconnect":
|
||||
if payload["connection_id"] in self.connections:
|
||||
self.remove_connection(connection_id=payload["connection_id"])
|
||||
connection_id = payload["connection_id"]
|
||||
connected_ip_address = self.connections[connection_id]["ip_address"]
|
||||
frame = kwargs.get("frame")
|
||||
if connected_ip_address == frame.ip.src_ip_address:
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Received disconnect command for {connection_id=} from {connected_ip_address}"
|
||||
)
|
||||
self.terminate_connection(connection_id=payload["connection_id"], send_disconnect=False)
|
||||
else:
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Ignoring disconnect command for {connection_id=} as the command source "
|
||||
f"({frame.ip.src_ip_address}) doesn't match the connection source ({connected_ip_address})"
|
||||
)
|
||||
elif payload["type"] == "sql":
|
||||
if payload.get("connection_id") in self.connections:
|
||||
result = self._process_sql(
|
||||
|
||||
@@ -72,7 +72,7 @@ class DNSClient(Service):
|
||||
|
||||
# check if DNS server is configured
|
||||
if self.dns_server is None:
|
||||
self.sys_log.error(f"{self.name}: DNS Server is not configured")
|
||||
self.sys_log.warning(f"{self.name}: DNS Server is not configured")
|
||||
return False
|
||||
|
||||
# check if the target domain is in the client's DNS cache
|
||||
@@ -88,7 +88,7 @@ class DNSClient(Service):
|
||||
else:
|
||||
# return False if already reattempted
|
||||
if is_reattempt:
|
||||
self.sys_log.info(f"{self.name}: Domain lookup for {target_domain} failed")
|
||||
self.sys_log.warning(f"{self.name}: Domain lookup for {target_domain} failed")
|
||||
return False
|
||||
else:
|
||||
# send a request to check if domain name exists in the DNS Server
|
||||
@@ -143,7 +143,8 @@ class DNSClient(Service):
|
||||
"""
|
||||
# The payload should be a DNS packet
|
||||
if not isinstance(payload, DNSPacket):
|
||||
_LOGGER.debug(f"{payload} is not a DNSPacket")
|
||||
self.sys_log.warning(f"{self.name}: Payload is not a DNSPacket")
|
||||
self.sys_log.debug(f"{self.name}: {payload}")
|
||||
return False
|
||||
|
||||
if payload.dns_reply is not None:
|
||||
@@ -156,5 +157,5 @@ class DNSClient(Service):
|
||||
self.dns_cache[payload.dns_request.domain_name_request] = payload.dns_reply.domain_name_ip_address
|
||||
return True
|
||||
|
||||
self.sys_log.error(f"Failed to resolve domain name {payload.dns_request.domain_name_request}")
|
||||
self.sys_log.warning(f"Failed to resolve domain name {payload.dns_request.domain_name_request}")
|
||||
return False
|
||||
|
||||
@@ -90,7 +90,8 @@ class DNSServer(Service):
|
||||
|
||||
# The payload should be a DNS packet
|
||||
if not isinstance(payload, DNSPacket):
|
||||
_LOGGER.debug(f"{payload} is not a DNSPacket")
|
||||
self.sys_log.warning(f"{payload} is not a DNSPacket")
|
||||
self.sys_log.debug(f"{payload} is not a DNSPacket")
|
||||
return False
|
||||
|
||||
# cast payload into a DNS packet
|
||||
|
||||
@@ -82,7 +82,7 @@ class FTPClient(FTPServiceABC):
|
||||
else:
|
||||
if is_reattempt:
|
||||
# reattempt failed
|
||||
self.sys_log.info(
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Unable to connect to FTP Server "
|
||||
f"{dest_ip_address} via port {payload.ftp_command_args.value}"
|
||||
)
|
||||
@@ -93,7 +93,7 @@ class FTPClient(FTPServiceABC):
|
||||
dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id, is_reattempt=True
|
||||
)
|
||||
else:
|
||||
self.sys_log.error(f"{self.name}: Unable to send FTPPacket")
|
||||
self.sys_log.warning(f"{self.name}: Unable to send FTPPacket")
|
||||
return False
|
||||
|
||||
def _disconnect_from_server(
|
||||
@@ -158,7 +158,7 @@ class FTPClient(FTPServiceABC):
|
||||
# check if the file to transfer exists on the client
|
||||
file_to_transfer: File = self.file_system.get_file(folder_name=src_folder_name, file_name=src_file_name)
|
||||
if not file_to_transfer:
|
||||
self.sys_log.error(f"Unable to send file that does not exist: {src_folder_name}/{src_file_name}")
|
||||
self.sys_log.warning(f"Unable to send file that does not exist: {src_folder_name}/{src_file_name}")
|
||||
return False
|
||||
|
||||
# check if FTP is currently connected to IP
|
||||
@@ -253,7 +253,8 @@ class FTPClient(FTPServiceABC):
|
||||
:type: session_id: Optional[str]
|
||||
"""
|
||||
if not isinstance(payload, FTPPacket):
|
||||
self.sys_log.error(f"{payload} is not an FTP packet")
|
||||
self.sys_log.warning(f"{self.name}: Payload is not an FTP packet")
|
||||
self.sys_log.debug(f"{self.name}: {payload}")
|
||||
return False
|
||||
|
||||
"""
|
||||
@@ -275,7 +276,7 @@ class FTPClient(FTPServiceABC):
|
||||
|
||||
# if QUIT succeeded, remove the session from active connection list
|
||||
if payload.ftp_command is FTPCommand.QUIT and payload.status_code is FTPStatusCode.OK:
|
||||
self.remove_connection(connection_id=session_id)
|
||||
self.terminate_connection(connection_id=session_id)
|
||||
|
||||
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
|
||||
|
||||
|
||||
@@ -61,7 +61,7 @@ class FTPServer(FTPServiceABC):
|
||||
return payload
|
||||
|
||||
if payload.ftp_command == FTPCommand.QUIT:
|
||||
self.remove_connection(connection_id=session_id)
|
||||
self.terminate_connection(connection_id=session_id)
|
||||
payload.status_code = FTPStatusCode.OK
|
||||
return payload
|
||||
|
||||
@@ -70,7 +70,8 @@ class FTPServer(FTPServiceABC):
|
||||
def receive(self, payload: Any, session_id: Optional[str] = None, **kwargs) -> bool:
|
||||
"""Receives a payload from the SessionManager."""
|
||||
if not isinstance(payload, FTPPacket):
|
||||
self.sys_log.error(f"{payload} is not an FTP packet")
|
||||
self.sys_log.warning(f"{self.name}: Payload is not an FTP packet")
|
||||
self.sys_log.debug(f"{self.name}: {payload}")
|
||||
return False
|
||||
|
||||
if not super().receive(payload=payload, session_id=session_id, **kwargs):
|
||||
|
||||
@@ -95,7 +95,7 @@ class ICMP(Service):
|
||||
network_interface = self.software_manager.session_manager.resolve_outbound_network_interface(target_ip_address)
|
||||
|
||||
if not network_interface:
|
||||
self.sys_log.error(
|
||||
self.sys_log.warning(
|
||||
"Cannot send ICMP echo request as there is no outbound Network Interface to use. Try configuring the "
|
||||
"default gateway."
|
||||
)
|
||||
@@ -130,7 +130,7 @@ class ICMP(Service):
|
||||
)
|
||||
|
||||
if not network_interface:
|
||||
self.sys_log.error(
|
||||
self.sys_log.warning(
|
||||
"Cannot send ICMP echo reply as there is no outbound Network Interface to use. Try configuring the "
|
||||
"default gateway."
|
||||
)
|
||||
|
||||
@@ -87,7 +87,7 @@ class NTPClient(Service):
|
||||
:return: True if successful, False otherwise.
|
||||
"""
|
||||
if not isinstance(payload, NTPPacket):
|
||||
_LOGGER.debug(f"{self.name}: Failed to parse NTP update")
|
||||
self.sys_log.warning(f"{self.name}: Failed to parse NTP update")
|
||||
return False
|
||||
if payload.ntp_reply.ntp_datetime:
|
||||
self.time = payload.ntp_reply.ntp_datetime
|
||||
@@ -115,7 +115,6 @@ class NTPClient(Service):
|
||||
:param timestep: The current timestep number. (Amount of time since simulation episode began)
|
||||
:type timestep: int
|
||||
"""
|
||||
self.sys_log.info(f"{self.name} apply_timestep")
|
||||
super().apply_timestep(timestep)
|
||||
if self.operating_state == ServiceOperatingState.RUNNING:
|
||||
# request time from server
|
||||
|
||||
@@ -51,7 +51,8 @@ class NTPServer(Service):
|
||||
:return: True if valid NTP request else False.
|
||||
"""
|
||||
if not (isinstance(payload, NTPPacket)):
|
||||
_LOGGER.debug(f"{payload} is not a NTPPacket")
|
||||
self.sys_log.warning(f"{self.name}: Payload is not a NTPPacket")
|
||||
self.sys_log.debug(f"{self.name}: {payload}")
|
||||
return False
|
||||
payload: NTPPacket = payload
|
||||
|
||||
|
||||
@@ -59,7 +59,7 @@ class Service(IOSoftware):
|
||||
|
||||
if self.operating_state is not ServiceOperatingState.RUNNING:
|
||||
# service is not running
|
||||
_LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
self.sys_log.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}")
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -187,6 +187,6 @@ class Service(IOSoftware):
|
||||
super().apply_timestep(timestep)
|
||||
if self.operating_state == ServiceOperatingState.RESTARTING:
|
||||
if self.restart_countdown <= 0:
|
||||
_LOGGER.debug(f"Restarting finished for service {self.name}")
|
||||
self.sys_log.debug(f"Restarting finished for service {self.name}")
|
||||
self.operating_state = ServiceOperatingState.RUNNING
|
||||
self.restart_countdown -= 1
|
||||
|
||||
@@ -11,7 +11,7 @@ from primaite.simulator.network.protocols.http import (
|
||||
)
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClientConnection
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
@@ -48,6 +48,7 @@ class WebServer(Service):
|
||||
super().__init__(**kwargs)
|
||||
self._install_web_files()
|
||||
self.start()
|
||||
self.db_connection: Optional[DatabaseClientConnection] = None
|
||||
|
||||
def _install_web_files(self):
|
||||
"""
|
||||
@@ -108,9 +109,11 @@ class WebServer(Service):
|
||||
|
||||
if path.startswith("users"):
|
||||
# get data from DatabaseServer
|
||||
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
|
||||
# get all users
|
||||
if db_client.query("SELECT"):
|
||||
if not self.db_connection:
|
||||
self._establish_db_connection()
|
||||
|
||||
if self.db_connection.query("SELECT"):
|
||||
# query succeeded
|
||||
self.set_health_state(SoftwareHealthState.GOOD)
|
||||
response.status_code = HttpStatusCode.OK
|
||||
@@ -123,6 +126,11 @@ class WebServer(Service):
|
||||
response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR
|
||||
return response
|
||||
|
||||
def _establish_db_connection(self) -> None:
|
||||
"""Establish a connection to db."""
|
||||
db_client = self.software_manager.software.get("DatabaseClient")
|
||||
self.db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
def send(
|
||||
self,
|
||||
payload: HttpResponsePacket,
|
||||
@@ -167,7 +175,8 @@ class WebServer(Service):
|
||||
|
||||
# check if the payload is an HTTPPacket
|
||||
if not isinstance(payload, HttpRequestPacket):
|
||||
self.sys_log.error("Payload is not an HTTPPacket")
|
||||
self.sys_log.warning(f"{self.name}: Payload is not an HTTPPacket")
|
||||
self.sys_log.debug(f"{self.name}: {payload}")
|
||||
return False
|
||||
|
||||
return self._process_http_request(payload=payload, session_id=session_id)
|
||||
|
||||
@@ -5,8 +5,10 @@ from enum import Enum
|
||||
from ipaddress import IPv4Address, IPv4Network
|
||||
from typing import Any, Dict, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from prettytable import MARKDOWN, PrettyTable
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.core import RequestManager, RequestType, SimComponent
|
||||
from primaite.simulator.file_system.file_system import FileSystem, Folder
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.transmission.network_layer import IPProtocol
|
||||
@@ -287,7 +289,9 @@ class IOSoftware(Software):
|
||||
Returns true if the software can perform actions.
|
||||
"""
|
||||
if self.software_manager and self.software_manager.node.operating_state != NodeOperatingState.ON:
|
||||
_LOGGER.debug(f"{self.name} Error: {self.software_manager.node.hostname} is not online.")
|
||||
self.software_manager.node.sys_log.error(
|
||||
f"{self.name} Error: {self.software_manager.node.hostname} is not online."
|
||||
)
|
||||
return False
|
||||
return True
|
||||
|
||||
@@ -296,7 +300,7 @@ class IOSoftware(Software):
|
||||
"""Return the public version of connections."""
|
||||
return copy.copy(self._connections)
|
||||
|
||||
def add_connection(self, connection_id: str, session_id: Optional[str] = None) -> bool:
|
||||
def add_connection(self, connection_id: Union[str, int], session_id: Optional[str] = None) -> bool:
|
||||
"""
|
||||
Create a new connection to this service.
|
||||
|
||||
@@ -308,7 +312,7 @@ class IOSoftware(Software):
|
||||
# if over or at capacity, set to overwhelmed
|
||||
if len(self._connections) >= self.max_sessions:
|
||||
self.set_health_state(SoftwareHealthState.OVERWHELMED)
|
||||
self.sys_log.error(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
self.sys_log.warning(f"{self.name}: Connect request for {connection_id=} declined. Service is at capacity.")
|
||||
return False
|
||||
else:
|
||||
# if service was previously overwhelmed, set to good because there is enough space for connections
|
||||
@@ -321,30 +325,53 @@ class IOSoftware(Software):
|
||||
if session_id:
|
||||
session_details = self._get_session_details(session_id)
|
||||
self._connections[connection_id] = {
|
||||
"session_id": session_id,
|
||||
"ip_address": session_details.with_ip_address if session_details else None,
|
||||
"time": datetime.now(),
|
||||
}
|
||||
self.sys_log.info(f"{self.name}: Connect request for {connection_id=} authorised")
|
||||
return True
|
||||
# connection with given id already exists
|
||||
self.sys_log.error(
|
||||
self.sys_log.warning(
|
||||
f"{self.name}: Connect request for {connection_id=} declined. Connection already exists."
|
||||
)
|
||||
return False
|
||||
|
||||
def remove_connection(self, connection_id: str) -> bool:
|
||||
def terminate_connection(self, connection_id: str, send_disconnect: bool = True) -> bool:
|
||||
"""
|
||||
Remove a connection from this service.
|
||||
Terminates a connection from this service.
|
||||
|
||||
Returns true if connection successfully removed
|
||||
|
||||
:param: connection_id: UUID of the connection to create
|
||||
:param send_disconnect: If True, sends a disconnect payload to the ip address of the associated session.
|
||||
:type: string
|
||||
"""
|
||||
if self.connections.get(connection_id):
|
||||
self._connections.pop(connection_id)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} closed.")
|
||||
return True
|
||||
connection_dict = self._connections.pop(connection_id)
|
||||
if send_disconnect:
|
||||
self.software_manager.send_payload_to_session_manager(
|
||||
payload={"type": "disconnect", "connection_id": connection_id},
|
||||
session_id=connection_dict["session_id"],
|
||||
)
|
||||
self.sys_log.info(f"{self.name}: Connection {connection_id=} terminated")
|
||||
return True
|
||||
return False
|
||||
|
||||
def show_connections(self, markdown: bool = False):
|
||||
"""
|
||||
Display the connections in tabular format.
|
||||
|
||||
:param markdown: Whether to display the table in Markdown format or not. Default is `False`.
|
||||
"""
|
||||
table = PrettyTable(["IP Address", "Connection ID", "Creation Timestamp"])
|
||||
if markdown:
|
||||
table.set_style(MARKDOWN)
|
||||
table.align = "l"
|
||||
table.title = f"{self.sys_log.hostname} {self.name} Connections"
|
||||
for connection_id, data in self.connections.items():
|
||||
table.add_row([data["ip_address"], connection_id, str(data["time"])])
|
||||
print(table.get_string(sortby="Creation Timestamp"))
|
||||
|
||||
def clear_connections(self):
|
||||
"""Clears all the connections from the software."""
|
||||
|
||||
11
src/primaite/utils/primaite_config_utils.py
Normal file
11
src/primaite/utils/primaite_config_utils.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import yaml
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
|
||||
def is_dev_mode() -> bool:
|
||||
"""Returns True if PrimAITE is currently running in developer mode."""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
return primaite_config["developer_mode"]
|
||||
@@ -8,6 +8,7 @@ io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
save_sys_logs: true
|
||||
sys_log_level: WARNING
|
||||
|
||||
|
||||
game:
|
||||
@@ -60,6 +61,102 @@ agents:
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
- hostname: client_3
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.23
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- switch_1:eth-1<->client_1:eth-1
|
||||
- switch_1:eth-2<->client_2:eth-1
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
options:
|
||||
nodes:
|
||||
- node_name: switch
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
- node_name: client_3
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.23
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: web_server
|
||||
service_name: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: true
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
@@ -75,6 +172,7 @@ simulation:
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: RansomwareScript
|
||||
- type: WebBrowser
|
||||
options:
|
||||
target_url: http://arcd.com/users/
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
# No green agents present
|
||||
greens: &greens []
|
||||
@@ -0,0 +1,34 @@
|
||||
agents: &greens
|
||||
- ref: green_A
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.2
|
||||
1: 0.8
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client
|
||||
@@ -0,0 +1,34 @@
|
||||
agents: &greens
|
||||
- ref: green_B
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.95
|
||||
1: 0.05
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client
|
||||
@@ -0,0 +1,2 @@
|
||||
# No red agents present
|
||||
reds: &reds []
|
||||
26
tests/assets/configs/scenario_with_placeholders/reds_1.yaml
Normal file
26
tests/assets/configs/scenario_with_placeholders/reds_1.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
reds: &reds
|
||||
- ref: red_A
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 10
|
||||
frequency: 10
|
||||
variance: 0
|
||||
26
tests/assets/configs/scenario_with_placeholders/reds_2.yaml
Normal file
26
tests/assets/configs/scenario_with_placeholders/reds_2.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
reds: &reds
|
||||
- ref: red_B
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 3
|
||||
frequency: 2
|
||||
variance: 1
|
||||
168
tests/assets/configs/scenario_with_placeholders/scenario.yaml
Normal file
168
tests/assets/configs/scenario_with_placeholders/scenario.yaml
Normal file
@@ -0,0 +1,168 @@
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: false
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 128
|
||||
ports:
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- *greens
|
||||
- *reds
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
routers: []
|
||||
hosts:
|
||||
- hostname: client
|
||||
- hostname: server
|
||||
num_services: 1
|
||||
num_applications: 1
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 1
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- client:eth-1<->switch_1:eth-1
|
||||
- server:eth-1<->switch_1:eth-2
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
options:
|
||||
node_id: 0
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
options:
|
||||
node_id: 1
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 0
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 1
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
- node_name: server
|
||||
|
||||
max_folders_per_node: 0
|
||||
max_files_per_folder: 0
|
||||
max_services_per_node: 0
|
||||
max_nics_per_node: 1
|
||||
max_acl_rules: 0
|
||||
ip_list:
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.40
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: false
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: client
|
||||
type: computer
|
||||
ip_address: 192.168.1.2
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.3
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
server_ip: 192.168.1.3
|
||||
payload: "DELETE"
|
||||
|
||||
- hostname: switch_1
|
||||
type: switch
|
||||
num_ports: 2
|
||||
|
||||
- hostname: server
|
||||
type: server
|
||||
ip_address: 192.168.1.3
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- type: DatabaseService
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: client
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
@@ -0,0 +1,14 @@
|
||||
base_scenario: scenario.yaml
|
||||
schedule:
|
||||
0:
|
||||
- greens_0.yaml
|
||||
- reds_0.yaml
|
||||
1:
|
||||
- greens_0.yaml
|
||||
- reds_1.yaml
|
||||
2:
|
||||
- greens_1.yaml
|
||||
- reds_1.yaml
|
||||
3:
|
||||
- greens_2.yaml
|
||||
- reds_2.yaml
|
||||
@@ -16,7 +16,7 @@ def test_sb3_compatibility():
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
gym = PrimaiteGymEnv(game_config=cfg)
|
||||
gym = PrimaiteGymEnv(env_config=cfg)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
|
||||
@@ -21,7 +21,7 @@ class TestPrimaiteEnvironment:
|
||||
"""Check that environment loads correctly from config and it can be reset."""
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
def env_checks():
|
||||
assert env is not None
|
||||
@@ -44,7 +44,7 @@ class TestPrimaiteEnvironment:
|
||||
"""Make sure you can go all the way through the session without errors."""
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
assert (num_actions := len(env.agent.action_manager.action_map)) == 54
|
||||
# run every action and make sure there's no crash
|
||||
@@ -88,4 +88,4 @@ class TestPrimaiteEnvironment:
|
||||
with open(MISCONFIGURED_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
@@ -4,7 +4,7 @@ from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
@@ -20,23 +20,23 @@ def test_data_manipulation(uc2_network):
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
db_service.backup_database()
|
||||
|
||||
# First check that the DB client on the web_server can successfully query the users table on the database
|
||||
assert db_client.query("SELECT")
|
||||
assert db_connection.query("SELECT")
|
||||
|
||||
# Now we run the DataManipulationBot
|
||||
db_manipulation_bot.attack()
|
||||
|
||||
# Now check that the DB client on the web_server cannot query the users table on the database
|
||||
assert not db_client.query("SELECT")
|
||||
assert not db_connection.query("SELECT")
|
||||
|
||||
# Now restore the database
|
||||
db_service.restore_backup()
|
||||
|
||||
# Now check that the DB client on the web_server can successfully query the users table on the database
|
||||
assert db_client.query("SELECT")
|
||||
assert db_connection.query("SELECT")
|
||||
|
||||
|
||||
def test_application_install_uninstall_on_uc2():
|
||||
@@ -44,7 +44,7 @@ def test_application_install_uninstall_on_uc2():
|
||||
with open(TEST_ASSETS_ROOT / "configs/test_application_install.yaml", "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
env.agent.flatten_obs = False
|
||||
env.reset()
|
||||
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from tests.conftest import TEST_ASSETS_ROOT
|
||||
|
||||
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"
|
||||
single_yaml_config = TEST_ASSETS_ROOT / "configs" / "test_primaite_session.yaml"
|
||||
with open(single_yaml_config, "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_type", [PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv])
|
||||
def test_creating_env_with_folder(env_type):
|
||||
"""Check that the environment can be created with a folder path."""
|
||||
|
||||
def check_taking_steps(e):
|
||||
if isinstance(e, PrimaiteRayMARLEnv):
|
||||
for i in range(9):
|
||||
e.step({k: i for k in e.game.rl_agents})
|
||||
else:
|
||||
for i in range(9):
|
||||
e.step(i)
|
||||
|
||||
env = env_type(env_config=folder_path)
|
||||
assert env is not None
|
||||
for _ in range(3): # do it multiple times to ensure it loops back to the beginning
|
||||
assert len(env.game.agents) == 1
|
||||
assert "defender" in env.game.agents
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
assert len(env.game.agents) == 2
|
||||
assert "defender" in env.game.agents
|
||||
assert "red_A" in env.game.agents
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
assert len(env.game.agents) == 3
|
||||
assert all([a in env.game.agents for a in ["defender", "green_A", "red_A"]])
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
assert len(env.game.agents) == 3
|
||||
assert all([a in env.game.agents for a in ["defender", "green_B", "red_B"]])
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_data, env_type",
|
||||
[
|
||||
(single_yaml_config, PrimaiteGymEnv),
|
||||
(single_yaml_config, PrimaiteRayEnv),
|
||||
(single_yaml_config, PrimaiteRayMARLEnv),
|
||||
(config_dict, PrimaiteGymEnv),
|
||||
(config_dict, PrimaiteRayEnv),
|
||||
(config_dict, PrimaiteRayMARLEnv),
|
||||
],
|
||||
)
|
||||
def test_creating_env_with_static_config(env_data, env_type):
|
||||
"""Check that the environment can be created with a single yaml file."""
|
||||
env = env_type(env_config=single_yaml_config)
|
||||
assert env is not None
|
||||
agents_before = len(env.game.agents)
|
||||
env.reset()
|
||||
assert len(env.game.agents) == agents_before
|
||||
@@ -0,0 +1,36 @@
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteGymEnv
|
||||
from primaite.simulator import LogLevel
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml"
|
||||
|
||||
|
||||
def load_config(config_path: Union[str, Path]) -> PrimaiteGame:
|
||||
"""Returns a PrimaiteGame object which loads the contents of a given yaml path."""
|
||||
with open(config_path, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
return PrimaiteGame.from_config(cfg)
|
||||
|
||||
|
||||
def test_io_settings():
|
||||
"""Test that the io_settings are loaded correctly."""
|
||||
with open(BASIC_CONFIG, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
assert env.io.settings is not None
|
||||
|
||||
assert env.io.settings.sys_log_level is LogLevel.WARNING
|
||||
assert env.io.settings.save_pcap_logs
|
||||
assert env.io.settings.save_sys_logs
|
||||
assert env.io.settings.save_step_metadata is False
|
||||
|
||||
assert env.io.settings.write_sys_log_to_terminal is False # false by default
|
||||
@@ -507,7 +507,7 @@ def test_firewall_acl_add_remove_rule_integration():
|
||||
with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
# 1: Check that traffic is normal and acl starts off with 4 rules.
|
||||
firewall = env.game.simulation.network.get_node_by_hostname("firewall")
|
||||
@@ -598,7 +598,7 @@ def test_firewall_port_disable_enable_integration():
|
||||
with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
firewall = env.game.simulation.network.get_node_by_hostname("firewall")
|
||||
|
||||
assert firewall.dmz_port.enabled == True
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_shared_reward():
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
env.reset()
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ from primaite.game.agent.observations.nic_observations import NICObservation
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.nmne import set_nmne_config
|
||||
from primaite.simulator.sim_container import Simulation
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
|
||||
|
||||
def test_capture_nmne(uc2_network):
|
||||
@@ -15,7 +15,7 @@ def test_capture_nmne(uc2_network):
|
||||
"""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa
|
||||
db_client.connect()
|
||||
db_client_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa
|
||||
|
||||
@@ -39,42 +39,42 @@ def test_capture_nmne(uc2_network):
|
||||
assert db_server_nic.nmne == {}
|
||||
|
||||
# Perform a "SELECT" query
|
||||
db_client.query("SELECT")
|
||||
db_client_connection.query(sql="SELECT")
|
||||
|
||||
# Check that it does not trigger an MNE capture.
|
||||
assert web_server_nic.nmne == {}
|
||||
assert db_server_nic.nmne == {}
|
||||
|
||||
# Perform a "DELETE" query
|
||||
db_client.query("DELETE")
|
||||
db_client_connection.query(sql="DELETE")
|
||||
|
||||
# Check that the web server's outbound interface and the database server's inbound interface register the MNE
|
||||
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 1}}}}
|
||||
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 1}}}}
|
||||
|
||||
# Perform another "SELECT" query
|
||||
db_client.query("SELECT")
|
||||
db_client_connection.query(sql="SELECT")
|
||||
|
||||
# Check that no additional MNEs are captured
|
||||
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 1}}}}
|
||||
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 1}}}}
|
||||
|
||||
# Perform another "DELETE" query
|
||||
db_client.query("DELETE")
|
||||
db_client_connection.query(sql="DELETE")
|
||||
|
||||
# Check that the web server and database server interfaces register an additional MNE
|
||||
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 2}}}}
|
||||
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 2}}}}
|
||||
|
||||
# Perform an "ENCRYPT" query
|
||||
db_client.query("ENCRYPT")
|
||||
db_client_connection.query(sql="ENCRYPT")
|
||||
|
||||
# Check that the web server and database server interfaces register an additional MNE
|
||||
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 3}}}}
|
||||
assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 3}}}}
|
||||
|
||||
# Perform another "SELECT" query
|
||||
db_client.query("SELECT")
|
||||
db_client_connection.query(sql="SELECT")
|
||||
|
||||
# Check that no additional MNEs are captured
|
||||
assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 3}}}}
|
||||
@@ -92,7 +92,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
"""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa
|
||||
db_client.connect()
|
||||
db_client_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa
|
||||
|
||||
@@ -119,7 +119,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {}
|
||||
|
||||
# Perform a "SELECT" query
|
||||
db_client.query("SELECT")
|
||||
db_client_connection.query(sql="SELECT")
|
||||
|
||||
# Check that it does not trigger an MNE capture.
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -129,7 +129,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {}
|
||||
|
||||
# Perform a "DELETE" query
|
||||
db_client.query("DELETE")
|
||||
db_client_connection.query(sql="DELETE")
|
||||
|
||||
# Check that the web server's outbound interface and the database server's inbound interface register the MNE
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -139,7 +139,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
|
||||
|
||||
# Perform another "SELECT" query
|
||||
db_client.query("SELECT")
|
||||
db_client_connection.query(sql="SELECT")
|
||||
|
||||
# Check that no additional MNEs are captured
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -149,7 +149,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}}
|
||||
|
||||
# Perform another "DELETE" query
|
||||
db_client.query("DELETE")
|
||||
db_client_connection.query(sql="DELETE")
|
||||
|
||||
# Check that the web server and database server interfaces register an additional MNE
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -159,7 +159,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}}
|
||||
|
||||
# Perform a "ENCRYPT" query
|
||||
db_client.query("ENCRYPT")
|
||||
db_client_connection.query(sql="ENCRYPT")
|
||||
|
||||
# Check that the web server's outbound interface and the database server's inbound interface register the MNE
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -169,7 +169,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 3}}}}
|
||||
|
||||
# Perform another "SELECT" query
|
||||
db_client.query("SELECT")
|
||||
db_client_connection.query(sql="SELECT")
|
||||
|
||||
# Check that no additional MNEs are captured
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -179,7 +179,7 @@ def test_describe_state_nmne(uc2_network):
|
||||
assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 3}}}}
|
||||
|
||||
# Perform another "ENCRYPT"
|
||||
db_client.query("ENCRYPT")
|
||||
db_client_connection.query(sql="ENCRYPT")
|
||||
|
||||
# Check that the web server and database server interfaces register an additional MNE
|
||||
web_server_nic_state = web_server_nic.describe_state()
|
||||
@@ -206,7 +206,7 @@ def test_capture_nmne_observations(uc2_network):
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client.connect()
|
||||
db_client_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
# Set the NMNE configuration to capture DELETE/ENCRYPT queries as MNEs
|
||||
nmne_config = {
|
||||
@@ -228,7 +228,7 @@ def test_capture_nmne_observations(uc2_network):
|
||||
for i in range(0, 20):
|
||||
# Perform a "DELETE" query each iteration
|
||||
for j in range(i):
|
||||
db_client.query("DELETE")
|
||||
db_client_connection.query(sql="DELETE")
|
||||
|
||||
# Observe the current state of NMNEs from the NICs of both the database and web servers
|
||||
state = sim.describe_state()
|
||||
@@ -253,7 +253,7 @@ def test_capture_nmne_observations(uc2_network):
|
||||
for i in range(0, 20):
|
||||
# Perform a "ENCRYPT" query each iteration
|
||||
for j in range(i):
|
||||
db_client.query("ENCRYPT")
|
||||
db_client_connection.query(sql="ENCRYPT")
|
||||
|
||||
# Observe the current state of NMNEs from the NICs of both the database and web servers
|
||||
state = sim.describe_state()
|
||||
|
||||
@@ -10,7 +10,7 @@ from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.simulator.system.applications.red_applications.data_manipulation_bot import (
|
||||
DataManipulationAttackStage,
|
||||
DataManipulationBot,
|
||||
@@ -141,8 +141,10 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
|
||||
green_db_connection: DatabaseClientConnection = green_db_client.get_new_connection()
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
|
||||
assert green_db_client.query("SELECT")
|
||||
assert green_db_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
|
||||
data_manipulation_bot.port_scan_p_of_success = 1
|
||||
@@ -151,5 +153,5 @@ def test_data_manipulation_disrupts_green_agent_connection(data_manipulation_db_
|
||||
data_manipulation_bot.attack()
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.COMPROMISED
|
||||
assert green_db_client.query("SELECT") is False
|
||||
assert green_db_connection.query("SELECT") is False
|
||||
assert green_db_client.last_query_response.get("status_code") != 200
|
||||
|
||||
@@ -10,7 +10,7 @@ from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import ApplicationOperatingState
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.simulator.system.applications.red_applications.ransomware_script import (
|
||||
RansomwareAttackStage,
|
||||
RansomwareScript,
|
||||
@@ -144,12 +144,13 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
|
||||
|
||||
client_2: Computer = network.get_node_by_hostname("client_2")
|
||||
green_db_client: DatabaseClient = client_2.software_manager.software.get("DatabaseClient")
|
||||
green_db_client_connection: DatabaseClientConnection = green_db_client.get_new_connection()
|
||||
|
||||
server: Server = network.get_node_by_hostname("server_1")
|
||||
db_server_service: DatabaseService = server.software_manager.software.get("DatabaseService")
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.GOOD
|
||||
assert green_db_client.query("SELECT")
|
||||
assert green_db_client_connection.query("SELECT")
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
|
||||
ransomware_script_application.target_scan_p_of_success = 1
|
||||
@@ -159,5 +160,5 @@ def test_ransomware_disrupts_green_agent_connection(ransomware_script_db_server_
|
||||
ransomware_script_application.attack()
|
||||
|
||||
assert db_server_service.db_file.health_status is FileSystemItemHealthStatus.CORRUPT
|
||||
assert green_db_client.query("SELECT") is True
|
||||
assert green_db_client_connection.query("SELECT") is True
|
||||
assert green_db_client.last_query_response.get("status_code") == 200
|
||||
|
||||
@@ -8,7 +8,8 @@ from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Server
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient
|
||||
from primaite.simulator.network.hardware.nodes.network.switch import Switch
|
||||
from primaite.simulator.system.applications.database_client import DatabaseClient, DatabaseClientConnection
|
||||
from primaite.simulator.system.services.database.database_service import DatabaseService
|
||||
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
@@ -56,11 +57,12 @@ def test_database_client_server_connection(peer_to_peer):
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"]
|
||||
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 1
|
||||
|
||||
assert len(db_client.client_connections) == 1
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
db_client.disconnect()
|
||||
assert len(db_client.connections) == 0
|
||||
assert len(db_client.client_connections) == 0
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
|
||||
@@ -73,7 +75,7 @@ def test_database_client_server_correct_password(peer_to_peer_secure_db):
|
||||
|
||||
db_client.configure(server_ip_address=IPv4Address("192.168.0.11"), server_password="12345")
|
||||
db_client.connect()
|
||||
assert len(db_client.connections) == 1
|
||||
assert len(db_client.client_connections) == 1
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
|
||||
@@ -95,14 +97,24 @@ def test_database_client_server_incorrect_password(peer_to_peer_secure_db):
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
|
||||
def test_database_client_query(uc2_network):
|
||||
def test_database_client_native_connection_query(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
db_client.connect()
|
||||
|
||||
assert db_client.query("SELECT")
|
||||
assert db_client.query("INSERT")
|
||||
assert db_client.query(sql="SELECT")
|
||||
assert db_client.query(sql="INSERT")
|
||||
|
||||
|
||||
def test_database_client_connection_query(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
assert db_connection.query(sql="SELECT")
|
||||
assert db_connection.query(sql="INSERT")
|
||||
|
||||
|
||||
def test_create_database_backup(uc2_network):
|
||||
@@ -172,7 +184,6 @@ def test_restore_backup_after_deleting_file_without_updating_scan(uc2_network):
|
||||
db_server: Server = uc2_network.get_node_by_hostname("database_server")
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
|
||||
|
||||
# create a back up
|
||||
assert db_service.backup_database() is True
|
||||
|
||||
db_service.db_file.corrupt() # corrupt the db
|
||||
@@ -211,10 +222,13 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
assert len(db_client.connections)
|
||||
assert len(db_client.client_connections)
|
||||
|
||||
assert db_client.query("SELECT") is True
|
||||
assert db_client.query("INSERT") is True
|
||||
# Establish a new connection to the DatabaseService
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
assert db_connection.query("SELECT") is True
|
||||
assert db_connection.query("INSERT") is True
|
||||
db_server.power_off()
|
||||
|
||||
for i in range(db_server.shut_down_duration + 1):
|
||||
@@ -223,5 +237,121 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
assert db_server.operating_state is NodeOperatingState.OFF
|
||||
assert db_service.operating_state is ServiceOperatingState.STOPPED
|
||||
|
||||
assert db_client.query("SELECT") is False
|
||||
assert db_client.query("INSERT") is False
|
||||
assert db_connection.query("SELECT") is False
|
||||
assert db_connection.query("INSERT") is False
|
||||
|
||||
|
||||
def test_database_client_uninstall_terminates_connections(peer_to_peer):
|
||||
node_a, node_b = peer_to_peer
|
||||
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
# Check that all connection counters are correct and that the client connection can query the database
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
assert len(db_client.client_connections) == 1
|
||||
|
||||
assert db_connection.is_active
|
||||
|
||||
assert db_connection.query("SELECT")
|
||||
|
||||
# Perform the DatabaseClient uninstall
|
||||
node_a.software_manager.uninstall("DatabaseClient")
|
||||
|
||||
# Check that all connection counters are updated accordingly and client connection can no longer query the database
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
assert len(db_client.client_connections) == 0
|
||||
|
||||
assert not db_connection.query("SELECT")
|
||||
|
||||
assert not db_connection.is_active
|
||||
|
||||
|
||||
def test_database_service_can_terminate_connection(peer_to_peer):
|
||||
node_a, node_b = peer_to_peer
|
||||
|
||||
db_client: DatabaseClient = node_a.software_manager.software["DatabaseClient"]
|
||||
db_service: DatabaseService = node_b.software_manager.software["DatabaseService"] # noqa
|
||||
|
||||
db_connection: DatabaseClientConnection = db_client.get_new_connection()
|
||||
|
||||
# Check that all connection counters are correct and that the client connection can query the database
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
assert len(db_client.client_connections) == 1
|
||||
|
||||
assert db_connection.is_active
|
||||
|
||||
assert db_connection.query("SELECT")
|
||||
|
||||
# Perform the server-led connection termination
|
||||
connection_id = next(iter(db_service.connections.keys()))
|
||||
db_service.terminate_connection(connection_id)
|
||||
|
||||
# Check that all connection counters are updated accordingly and client connection can no longer query the database
|
||||
assert len(db_service.connections) == 0
|
||||
|
||||
assert len(db_client.client_connections) == 0
|
||||
|
||||
assert not db_connection.query("SELECT")
|
||||
|
||||
assert not db_connection.is_active
|
||||
|
||||
|
||||
def test_client_connection_terminate_does_not_terminate_another_clients_connection():
|
||||
network = Network()
|
||||
|
||||
db_server = Server(
|
||||
hostname="db_client", ip_address="192.168.0.11", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
db_server.power_on()
|
||||
|
||||
db_server.software_manager.install(DatabaseService)
|
||||
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"] # noqa
|
||||
db_service.start()
|
||||
|
||||
client_a = Computer(
|
||||
hostname="client_a", ip_address="192.168.0.12", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
client_a.power_on()
|
||||
|
||||
client_a.software_manager.install(DatabaseClient)
|
||||
client_a.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
|
||||
client_a.software_manager.software["DatabaseClient"].run()
|
||||
|
||||
client_b = Computer(
|
||||
hostname="client_b", ip_address="192.168.0.13", subnet_mask="255.255.255.0", start_up_duration=0
|
||||
)
|
||||
client_b.power_on()
|
||||
|
||||
client_b.software_manager.install(DatabaseClient)
|
||||
client_b.software_manager.software["DatabaseClient"].configure(server_ip_address=IPv4Address("192.168.0.11"))
|
||||
client_b.software_manager.software["DatabaseClient"].run()
|
||||
|
||||
switch = Switch(hostname="switch", start_up_duration=0, num_ports=3)
|
||||
switch.power_on()
|
||||
|
||||
network.connect(endpoint_a=switch.network_interface[1], endpoint_b=db_server.network_interface[1])
|
||||
network.connect(endpoint_a=switch.network_interface[2], endpoint_b=client_a.network_interface[1])
|
||||
network.connect(endpoint_a=switch.network_interface[3], endpoint_b=client_b.network_interface[1])
|
||||
|
||||
db_client_a: DatabaseClient = client_a.software_manager.software["DatabaseClient"] # noqa
|
||||
db_connection_a = db_client_a.get_new_connection()
|
||||
|
||||
assert db_connection_a.query("SELECT")
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
db_client_b: DatabaseClient = client_b.software_manager.software["DatabaseClient"] # noqa
|
||||
db_connection_b = db_client_b.get_new_connection()
|
||||
|
||||
assert db_connection_b.query("SELECT")
|
||||
assert len(db_service.connections) == 2
|
||||
|
||||
db_connection_a.disconnect()
|
||||
|
||||
assert db_connection_b.query("SELECT")
|
||||
assert len(db_service.connections) == 1
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
import pytest
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router
|
||||
@@ -97,7 +98,7 @@ def test_request_fails_if_node_off(example_network, node_request):
|
||||
class TestDataManipulationGreenRequests:
|
||||
def test_node_off(self, uc2_network):
|
||||
"""Test that green requests succeed when the node is on and fail if the node is off."""
|
||||
net = uc2_network
|
||||
net: Network = uc2_network
|
||||
|
||||
client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"])
|
||||
client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"])
|
||||
@@ -131,7 +132,7 @@ class TestDataManipulationGreenRequests:
|
||||
|
||||
def test_acl_block(self, uc2_network):
|
||||
"""Test that green requests succeed when not blocked by ACLs but fail when blocked."""
|
||||
net = uc2_network
|
||||
net: Network = uc2_network
|
||||
|
||||
router: Router = net.get_node_by_hostname("router_1")
|
||||
client_1: HostNode = net.get_node_by_hostname("client_1")
|
||||
|
||||
0
tests/unit_tests/_primaite/_session/__init__.py
Normal file
0
tests/unit_tests/_primaite/_session/__init__.py
Normal file
50
tests/unit_tests/_primaite/_session/test_episode_schedule.py
Normal file
50
tests/unit_tests/_primaite/_session/test_episode_schedule.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.session.episode_schedule import ConstantEpisodeScheduler, EpisodeListScheduler
|
||||
|
||||
|
||||
def test_episode_list_scheduler():
|
||||
# Initialize an instance of EpisodeListScheduler
|
||||
|
||||
# Define a schedule and episode data for testing
|
||||
schedule = {0: ["episode1"], 1: ["episode2"]}
|
||||
episode_data = {"episode1": "data1: 1", "episode2": "data2: 2"}
|
||||
base_scenario = """agents: []"""
|
||||
|
||||
scheduler = EpisodeListScheduler(schedule=schedule, episode_data=episode_data, base_scenario=base_scenario)
|
||||
# Test when episode number is within the schedule
|
||||
result = scheduler(0)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data1: 1\nagents: []") == result
|
||||
|
||||
# Test next episode
|
||||
result = scheduler(1)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data2: 2\nagents: []") == result
|
||||
|
||||
# Test when episode number exceeds the schedule
|
||||
result = scheduler(2)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data1: 1\nagents: []") == result
|
||||
assert scheduler._exceeded_episode_list
|
||||
|
||||
# Test when episode number is a sequence
|
||||
scheduler.schedule = {0: ["episode1", "episode2"]}
|
||||
result = scheduler(0)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data1: 1\ndata2: 2\nagents: []") == result
|
||||
|
||||
|
||||
def test_constant_episode_scheduler():
|
||||
# Initialize an instance of ConstantEpisodeScheduler
|
||||
config = {"key": "value"}
|
||||
scheduler = ConstantEpisodeScheduler(config=config)
|
||||
|
||||
result = scheduler(0)
|
||||
assert isinstance(result, dict)
|
||||
assert {"key": "value"} == result
|
||||
|
||||
result = scheduler(1)
|
||||
assert isinstance(result, dict)
|
||||
assert {"key": "value"} == result
|
||||
@@ -70,7 +70,7 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot):
|
||||
dm_bot._perform_data_manipulation(p_of_success=1.0)
|
||||
|
||||
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
|
||||
assert len(dm_bot._host_db_client.connections)
|
||||
assert len(dm_bot._host_db_client.client_connections)
|
||||
|
||||
|
||||
def test_dm_bot_fails_without_db_client(dm_client):
|
||||
|
||||
@@ -56,7 +56,11 @@ def test_connect_to_database_fails_on_reattempt(database_client_on_computer):
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.connected = False
|
||||
assert database_client._connect(server_ip_address=IPv4Address("192.168.0.1"), is_reattempt=True) is False
|
||||
|
||||
database_connection = database_client._connect(
|
||||
server_ip_address=IPv4Address("192.168.0.1"), connection_request_id="", is_reattempt=True
|
||||
)
|
||||
assert database_connection is None
|
||||
|
||||
|
||||
def test_disconnect_when_client_is_closed(database_client_on_computer):
|
||||
@@ -79,21 +83,20 @@ def test_disconnect(database_client_on_computer):
|
||||
"""Database client should remove the connection."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
assert not database_client.connected
|
||||
assert database_client.connected is False
|
||||
|
||||
database_client.connect()
|
||||
|
||||
assert database_client.connected
|
||||
assert database_client.connected is True
|
||||
|
||||
database_client.disconnect()
|
||||
|
||||
assert not database_client.connected
|
||||
assert database_client.connected is False
|
||||
|
||||
|
||||
def test_query_when_client_is_closed(database_client_on_computer):
|
||||
"""Database client should return False when it is not running."""
|
||||
database_client, computer = database_client_on_computer
|
||||
|
||||
database_client.close()
|
||||
assert database_client.operating_state is ApplicationOperatingState.CLOSED
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.system.services.service import ServiceOperatingState
|
||||
from primaite.simulator.system.software import SoftwareHealthState
|
||||
|
||||
@@ -179,7 +181,8 @@ def test_overwhelm_service(service):
|
||||
assert service.health_state_actual is SoftwareHealthState.OVERWHELMED
|
||||
|
||||
|
||||
def test_create_and_remove_connections(service):
|
||||
@pytest.mark.xfail(reason="Fails as it's now too simple. Needs to be be refactored so that uses a service on a node.")
|
||||
def test_create_and_terminate_connections(service):
|
||||
service.start()
|
||||
uuid = str(uuid4())
|
||||
|
||||
@@ -187,6 +190,6 @@ def test_create_and_remove_connections(service):
|
||||
assert len(service.connections) == 1
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
assert service.remove_connection(connection_id=uuid) # should be true
|
||||
assert service.terminate_connection(connection_id=uuid) # should be true
|
||||
assert len(service.connections) == 0
|
||||
assert service.health_state_actual is SoftwareHealthState.GOOD
|
||||
|
||||
@@ -0,0 +1,126 @@
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def syslog() -> SysLog:
|
||||
return SysLog(hostname="test")
|
||||
|
||||
|
||||
def test_debug_sys_log_level(syslog, capsys):
|
||||
"""Test that the debug log level logs debug syslogs and above."""
|
||||
SIM_OUTPUT.sys_log_level = LogLevel.DEBUG
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = True
|
||||
|
||||
test_string = str(uuid4())
|
||||
|
||||
syslog.debug(test_string)
|
||||
syslog.info(test_string)
|
||||
syslog.warning(test_string)
|
||||
syslog.error(test_string)
|
||||
syslog.critical(test_string)
|
||||
|
||||
captured = "".join(capsys.readouterr())
|
||||
|
||||
assert test_string in captured
|
||||
assert "DEBUG" in captured
|
||||
assert "INFO" in captured
|
||||
assert "WARNING" in captured
|
||||
assert "ERROR" in captured
|
||||
assert "CRITICAL" in captured
|
||||
|
||||
|
||||
def test_info_sys_log_level(syslog, capsys):
|
||||
"""Test that the debug log level logs debug syslogs and above."""
|
||||
SIM_OUTPUT.sys_log_level = LogLevel.INFO
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = True
|
||||
|
||||
test_string = str(uuid4())
|
||||
|
||||
syslog.debug(test_string)
|
||||
syslog.info(test_string)
|
||||
syslog.warning(test_string)
|
||||
syslog.error(test_string)
|
||||
syslog.critical(test_string)
|
||||
|
||||
captured = "".join(capsys.readouterr())
|
||||
|
||||
assert test_string in captured
|
||||
assert "DEBUG" not in captured
|
||||
assert "INFO" in captured
|
||||
assert "WARNING" in captured
|
||||
assert "ERROR" in captured
|
||||
assert "CRITICAL" in captured
|
||||
|
||||
|
||||
def test_warning_sys_log_level(syslog, capsys):
|
||||
"""Test that the debug log level logs debug syslogs and above."""
|
||||
SIM_OUTPUT.sys_log_level = LogLevel.WARNING
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = True
|
||||
|
||||
test_string = str(uuid4())
|
||||
|
||||
syslog.debug(test_string)
|
||||
syslog.info(test_string)
|
||||
syslog.warning(test_string)
|
||||
syslog.error(test_string)
|
||||
syslog.critical(test_string)
|
||||
|
||||
captured = "".join(capsys.readouterr())
|
||||
|
||||
assert test_string in captured
|
||||
assert "DEBUG" not in captured
|
||||
assert "INFO" not in captured
|
||||
assert "WARNING" in captured
|
||||
assert "ERROR" in captured
|
||||
assert "CRITICAL" in captured
|
||||
|
||||
|
||||
def test_error_sys_log_level(syslog, capsys):
|
||||
"""Test that the debug log level logs debug syslogs and above."""
|
||||
SIM_OUTPUT.sys_log_level = LogLevel.ERROR
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = True
|
||||
|
||||
test_string = str(uuid4())
|
||||
|
||||
syslog.debug(test_string)
|
||||
syslog.info(test_string)
|
||||
syslog.warning(test_string)
|
||||
syslog.error(test_string)
|
||||
syslog.critical(test_string)
|
||||
|
||||
captured = "".join(capsys.readouterr())
|
||||
|
||||
assert test_string in captured
|
||||
assert "DEBUG" not in captured
|
||||
assert "INFO" not in captured
|
||||
assert "WARNING" not in captured
|
||||
assert "ERROR" in captured
|
||||
assert "CRITICAL" in captured
|
||||
|
||||
|
||||
def test_critical_sys_log_level(syslog, capsys):
|
||||
"""Test that the debug log level logs debug syslogs and above."""
|
||||
SIM_OUTPUT.sys_log_level = LogLevel.CRITICAL
|
||||
SIM_OUTPUT.write_sys_log_to_terminal = True
|
||||
|
||||
test_string = str(uuid4())
|
||||
|
||||
syslog.debug(test_string)
|
||||
syslog.info(test_string)
|
||||
syslog.warning(test_string)
|
||||
syslog.error(test_string)
|
||||
syslog.critical(test_string)
|
||||
|
||||
captured = "".join(capsys.readouterr())
|
||||
|
||||
assert test_string in captured
|
||||
assert "DEBUG" not in captured
|
||||
assert "INFO" not in captured
|
||||
assert "WARNING" not in captured
|
||||
assert "ERROR" not in captured
|
||||
assert "CRITICAL" in captured
|
||||
Reference in New Issue
Block a user