Merge branch 'dev' into feature/2457-Set_link_bandwidth_via_config

This commit is contained in:
Charlie Crane
2024-05-14 14:44:20 +01:00
44 changed files with 1059 additions and 354 deletions

View File

@@ -14,13 +14,13 @@ parameters:
- name: matrix - name: matrix
type: object type: object
default: default:
# - job_name: 'UbuntuPython38' - job_name: 'UbuntuPython38'
# py: '3.8' py: '3.8'
# img: 'ubuntu-latest' img: 'ubuntu-latest'
# every_time: false every_time: false
# publish_coverage: false publish_coverage: false
- job_name: 'UbuntuPython310' - job_name: 'UbuntuPython311'
py: '3.10' py: '3.11'
img: 'ubuntu-latest' img: 'ubuntu-latest'
every_time: true every_time: true
publish_coverage: true publish_coverage: true
@@ -29,8 +29,8 @@ parameters:
img: 'windows-latest' img: 'windows-latest'
every_time: false every_time: false
publish_coverage: false publish_coverage: false
- job_name: 'WindowsPython310' - job_name: 'WindowsPython311'
py: '3.10' py: '3.11'
img: 'windows-latest' img: 'windows-latest'
every_time: false every_time: false
publish_coverage: false publish_coverage: false
@@ -39,8 +39,8 @@ parameters:
img: 'macOS-latest' img: 'macOS-latest'
every_time: false every_time: false
publish_coverage: false publish_coverage: false
- job_name: 'MacOSPython310' - job_name: 'MacOSPython311'
py: '3.10' py: '3.11'
img: 'macOS-latest' img: 'macOS-latest'
every_time: false every_time: false
publish_coverage: false publish_coverage: false

2
.gitignore vendored
View File

@@ -82,6 +82,7 @@ target/
# Jupyter Notebook # Jupyter Notebook
.ipynb_checkpoints .ipynb_checkpoints
PPO_UC2/
# IPython # IPython
profile_default/ profile_default/
@@ -150,6 +151,7 @@ docs/source/primaite-dependencies.rst
# outputs # outputs
src/primaite/outputs/ src/primaite/outputs/
simulation_output/ simulation_output/
sessions/
# benchmark session outputs # benchmark session outputs
benchmark/output benchmark/output

View File

@@ -14,14 +14,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Added ability to define scenarios that change depending on the episode number. - Added ability to define scenarios that change depending on the episode number.
- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config` - Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config`
- Database Connection ID's are now created/issued by DatabaseService and not DatabaseClient - Database Connection ID's are now created/issued by DatabaseService and not DatabaseClient
- added ability to set PrimAITE between development and production modes via PrimAITE CLI ``mode`` command
- Updated DatabaseClient so that it can now have a single native DatabaseClientConnection along with a collection of DatabaseClientConnection's. - Updated DatabaseClient so that it can now have a single native DatabaseClientConnection along with a collection of DatabaseClientConnection's.
- Implemented the uninstall functionality for DatabaseClient so that all connections are terminated at the DatabaseService. - Implemented the uninstall functionality for DatabaseClient so that all connections are terminated at the DatabaseService.
- Added the ability for a DatabaseService to terminate a connection. - Added the ability for a DatabaseService to terminate a connection.
- Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used. - Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used.
- Added additional show functions to enable connection inspection. - Added additional show functions to enable connection inspection.
- Updates to agent logging, to include the reward both per step and per episode. - Updates to agent logging, to include the reward both per step and per episode.
- Introduced Developer CLI tools to assist with developing/debugging PrimAITE
- Can be enabled via `primaite dev-mode enable`
- Activating dev-mode will change the location where the sessions will be output - by default will output where the PrimAITE repository is located
- Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization. - Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization.
- Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training.
## [Unreleased] ## [Unreleased]

View File

@@ -116,6 +116,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
:caption: Developer information: :caption: Developer information:
:hidden: :hidden:
source/developer_tools
source/state_system source/state_system
source/request_system source/request_system
PrimAITE API <source/_autosummary/primaite> PrimAITE API <source/_autosummary/primaite>

View File

@@ -0,0 +1,210 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _Developer Tools:
Developer Tools
***************
PrimAITE includes developer CLI tools that are intended to be used by developers.
dev-mode
========
The dev-mode contains configuration which override any of the config files during runtime.
This is intended to make debugging easier by removing the need to find the relevant configuration file/settings.
Enabling dev-mode
-----------------
The PrimAITE dev-mode can be enabled via the use of
.. code-block::
primaite dev-mode enable
Disabling dev-mode
------------------
The PrimAITE dev-mode can be disabled via the use of
.. code-block::
primaite dev-mode disable
Show current mode
-----------------
To show if the dev-mode is enabled or not, use
The PrimAITE dev-mode can be disabled via the use of
.. code-block::
primaite dev-mode show
dev-mode configuration
======================
The following configures some specific items that the dev-mode overrides, if enabled.
`--sys-log-level` or `-level`
----------------------------
The level of system logs can be overridden by dev-mode.
By default, this is set to DEBUG
The available options are [DEBUG|INFO|WARNING|ERROR|CRITICAL]
.. code-block::
primaite dev-mode config -level INFO
or
.. code-block::
primaite dev-mode config --sys-log-level INFO
`--output-sys-logs` or `-sys`
-----------------------------
The outputting of system logs can be overridden by dev-mode.
By default, this is set to False
Enabling system logs
""""""""""""""""""""
To enable outputting of system logs
.. code-block::
primaite dev-mode config --output-sys-logs
or
.. code-block::
primaite dev-mode config -sys
Disabling system logs
"""""""""""""""""""""
To disable outputting of system logs
.. code-block::
primaite dev-mode config --no-sys-logs
or
.. code-block::
primaite dev-mode config -nsys
`--output-pcap-logs` or `-pcap`
-------------------------------
The outputting of packet capture logs can be overridden by dev-mode.
By default, this is set to False
Enabling PCAP logs
""""""""""""""""""
To enable outputting of packet capture logs
.. code-block::
primaite dev-mode config --output-pcap-logs
or
.. code-block::
primaite dev-mode config -pcap
Disabling PCAP logs
"""""""""""""""""""
To disable outputting of packet capture logs
.. code-block::
primaite dev-mode config --no-pcap-logs
or
.. code-block::
primaite dev-mode config -npcap
`--output-to-terminal` or `-t`
------------------------------
The outputting of system logs to the terminal can be overridden by dev-mode.
By default, this is set to False
Enabling system log output to terminal
""""""""""""""""""""""""""""""""""""""
To enable outputting of system logs to terminal
.. code-block::
primaite dev-mode config --output-to-terminal
or
.. code-block::
primaite dev-mode config -t
Disabling system log output to terminal
"""""""""""""""""""""""""""""""""""""""
To disable outputting of system logs to terminal
.. code-block::
primaite dev-mode config --no-terminal
or
.. code-block::
primaite dev-mode config -nt
path
----
PrimAITE dev-mode can override where sessions are output.
By default, PrimAITE will output the sessions in USER_HOME/primaite/sessions
With dev-mode enabled, by default, this will be changed to PRIMAITE_REPOSITORY_ROOT/sessions
However, providing a path will let dev-mode output sessions to the given path e.g.
.. code-block:: bash
:caption: Unix
primaite dev-mode config path ~/output/path
.. code-block:: powershell
:caption: Windows (Powershell)
primaite dev-mode config path ~\output\path
default path
""""""""""""
To reset the path to use the PRIMAITE_REPOSITORY_ROOT/sessions, run the command
.. code-block::
primaite dev-mode config path --default

View File

@@ -161,9 +161,11 @@ To set PrimAITE to run in development mode:
.. code-block:: bash .. code-block:: bash
:caption: Unix :caption: Unix
primaite mode --dev primaite dev-mode enable
.. code-block:: powershell .. code-block:: powershell
:caption: Windows (Powershell) :caption: Windows (Powershell)
primaite mode --dev primaite dev-mode enable
More information about :ref:`Developer Tools`

View File

@@ -7,7 +7,7 @@ name = "primaite"
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme." description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}] authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}]
license = {file = "LICENSE"} license = {file = "LICENSE"}
requires-python = ">=3.8, <3.11" requires-python = ">=3.8, <3.12"
dynamic = ["version", "readme"] dynamic = ["version", "readme"]
classifiers = [ classifiers = [
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
@@ -20,6 +20,7 @@ classifiers = [
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3 :: Only",
] ]

View File

@@ -1 +1 @@
3.0.0b9dev 3.0.0b9

View File

@@ -122,35 +122,20 @@ class _PrimaitePaths:
PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths() PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
def _host_primaite_config() -> None:
if not PRIMAITE_PATHS.app_config_file_path.exists():
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_host_primaite_config()
def _get_primaite_config() -> Dict: def _get_primaite_config() -> Dict:
config_path = PRIMAITE_PATHS.app_config_file_path config_path = PRIMAITE_PATHS.app_config_file_path
if not config_path.exists(): if not config_path.exists():
# load from package if config does not exist
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml")) config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
# generate app config
shutil.copy2(config_path, PRIMAITE_PATHS.app_config_file_path)
with open(config_path, "r") as file: with open(config_path, "r") as file:
# load from config
primaite_config = yaml.safe_load(file) primaite_config = yaml.safe_load(file)
log_level_map = { return primaite_config
"NOTSET": logging.NOTSET,
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARN": logging.WARN,
"WARNING": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
return primaite_config
_PRIMAITE_CONFIG = _get_primaite_config() PRIMAITE_CONFIG = _get_primaite_config()
class _LevelFormatter(Formatter): class _LevelFormatter(Formatter):
@@ -177,11 +162,11 @@ class _LevelFormatter(Formatter):
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter( _LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
{ {
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"], logging.DEBUG: PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"], logging.INFO: PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"], logging.WARNING: PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"], logging.ERROR: PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"], logging.CRITICAL: PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
} }
) )
@@ -193,10 +178,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
backupCount=9, # Max 100MB of logs backupCount=9, # Max 100MB of logs
encoding="utf8", encoding="utf8",
) )
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) _STREAM_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"]) _FILE_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"] _LOG_FORMAT_STR: Final[str] = PRIMAITE_CONFIG["logging"]["logger_format"]
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER) _STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER) _FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
@@ -215,6 +200,6 @@ def getLogger(name: str) -> Logger: # noqa
logging config. logging config.
""" """
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(_PRIMAITE_CONFIG["log_level"]) logger.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
return logger return logger

View File

@@ -2,16 +2,21 @@
"""Provides a CLI using Typer as an entry point.""" """Provides a CLI using Typer as an entry point."""
import logging import logging
import os import os
import shutil
from enum import Enum from enum import Enum
from pathlib import Path
from typing import Optional from typing import Optional
import pkg_resources
import typer import typer
import yaml import yaml
from typing_extensions import Annotated from typing_extensions import Annotated
from primaite import PRIMAITE_PATHS from primaite import PRIMAITE_PATHS
from primaite.utils.cli import dev_cli
app = typer.Typer(no_args_is_help=True) app = typer.Typer(no_args_is_help=True)
app.add_typer(dev_cli.dev, name="dev-mode")
@app.command() @app.command()
@@ -89,7 +94,7 @@ def version() -> None:
@app.command() @app.command()
def setup(overwrite_existing: bool = True) -> None: def setup(overwrite_existing: bool = False) -> None:
""" """
Perform the PrimAITE first-time setup. Perform the PrimAITE first-time setup.
@@ -102,11 +107,14 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Performing the PrimAITE first-time setup...") _LOGGER.info("Performing the PrimAITE first-time setup...")
_LOGGER.info("Building primaite_config.yaml...")
_LOGGER.info("Building the PrimAITE app directories...") _LOGGER.info("Building the PrimAITE app directories...")
PRIMAITE_PATHS.mkdirs() PRIMAITE_PATHS.mkdirs()
_LOGGER.info("Building primaite_config.yaml...")
if overwrite_existing:
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
shutil.copy(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
_LOGGER.info("Rebuilding the demo notebooks...") _LOGGER.info("Rebuilding the demo notebooks...")
reset_demo_notebooks.run(overwrite_existing=True) reset_demo_notebooks.run(overwrite_existing=True)
@@ -114,47 +122,3 @@ def setup(overwrite_existing: bool = True) -> None:
reset_example_configs.run(overwrite_existing=True) reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("PrimAITE setup complete!") _LOGGER.info("PrimAITE setup complete!")
@app.command()
def mode(
dev: Annotated[bool, typer.Option("--dev", help="Activates PrimAITE developer mode")] = None,
prod: Annotated[bool, typer.Option("--prod", help="Activates PrimAITE production mode")] = None,
) -> None:
"""
Switch PrimAITE between developer mode and production mode.
By default, PrimAITE will be in production mode.
To view the current mode, use: primaite mode
To set to development mode, use: primaite mode --dev
To return to production mode, use: primaite mode --prod
"""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
if dev and prod:
print("Unable to activate developer and production modes concurrently.")
return
if (dev is None) and (prod is None):
is_dev_mode = primaite_config["developer_mode"]
if is_dev_mode:
print("PrimAITE is running in developer mode.")
else:
print("PrimAITE is running in production mode.")
if dev:
# activate dev mode
primaite_config["developer_mode"] = True
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print("PrimAITE is running in developer mode.")
if prod:
# activate prod mode
primaite_config["developer_mode"] = False
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
yaml.dump(primaite_config, file)
print("PrimAITE is running in production mode.")

View File

@@ -45,15 +45,9 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"print(cfg['agents'][2]['agent_settings'])" "for agent in cfg['agents']:\n",
] " if agent[\"ref\"] == \"defender\":\n",
}, " agent['agent_settings']['flatten_obs'] = True\n",
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_config = cfg\n", "env_config = cfg\n",
"\n", "\n",
"config = (\n", "config = (\n",
@@ -80,7 +74,7 @@
"tune.Tuner(\n", "tune.Tuner(\n",
" \"PPO\",\n", " \"PPO\",\n",
" run_config=air.RunConfig(\n", " run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n", " stop={\"timesteps_total\": 1e3 * 128}\n",
" ),\n", " ),\n",
" param_space=config\n", " param_space=config\n",
").fit()\n" ").fit()\n"

View File

@@ -0,0 +1,148 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple multi-processing demo using SubprocVecEnv from SB3\n",
"Based on a code example provided by Rachael Proctor."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Import packages and read config file."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from stable_baselines3 import PPO\n",
"from stable_baselines3.common.utils import set_random_seed\n",
"from stable_baselines3.common.vec_env import SubprocVecEnv\n",
"\n",
"from primaite.session.environment import PrimaiteGymEnv\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.config.load import data_manipulation_config_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up training data."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"EPISODE_LEN = 128\n",
"NUM_EPISODES = 10\n",
"NO_STEPS = EPISODE_LEN * NUM_EPISODES\n",
"BATCH_SIZE = 32\n",
"LEARNING_RATE = 3e-4\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define an environment function."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"\n",
"def make_env(rank: int, seed: int = 0) -> callable:\n",
" \"\"\"Wrapper script for _init function.\"\"\"\n",
"\n",
" def _init() -> PrimaiteGymEnv:\n",
" env = PrimaiteGymEnv(env_config=cfg)\n",
" env.reset(seed=seed + rank)\n",
" model = PPO(\n",
" \"MlpPolicy\",\n",
" env,\n",
" learning_rate=LEARNING_RATE,\n",
" n_steps=NO_STEPS,\n",
" batch_size=BATCH_SIZE,\n",
" verbose=0,\n",
" tensorboard_log=\"./PPO_UC2/\",\n",
" )\n",
" model.learn(total_timesteps=NO_STEPS)\n",
" return env\n",
"\n",
" set_random_seed(seed)\n",
" return _init\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run experiment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"n_procs = 2\n",
"train_env = SubprocVecEnv([make_env(i + n_procs) for i in range(n_procs)])\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -5,9 +5,9 @@ from typing import Dict, List, Optional
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from primaite import getLogger, PRIMAITE_PATHS from primaite import _PRIMAITE_ROOT, getLogger, PRIMAITE_CONFIG, PRIMAITE_PATHS
from primaite.simulator import LogLevel, SIM_OUTPUT from primaite.simulator import LogLevel, SIM_OUTPUT
from primaite.utils.primaite_config_utils import is_dev_mode from primaite.utils.cli.primaite_config_utils import is_dev_mode
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@@ -62,12 +62,15 @@ class PrimaiteIO:
date_str = timestamp.strftime("%Y-%m-%d") date_str = timestamp.strftime("%Y-%m-%d")
time_str = timestamp.strftime("%H-%M-%S") time_str = timestamp.strftime("%H-%M-%S")
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
# check if running in dev mode # check if running in dev mode
if is_dev_mode(): if is_dev_mode():
# if dev mode, simulation output will be the current working directory session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str
session_path = Path.cwd() / "simulation_output" / date_str / time_str
else: # check if there is an output directory set in config
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
session_path = Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"]) / "sessions" / date_str / time_str
session_path.mkdir(exist_ok=True, parents=True) session_path.mkdir(exist_ok=True, parents=True)
return session_path return session_path

View File

@@ -1,6 +1,12 @@
# The main PrimAITE application config file # The main PrimAITE application config file
developer_mode: False # false by default developer_mode:
enabled: False # not enabled by default
sys_log_level: DEBUG # level of output for system logs, DEBUG by default
output_sys_logs: False # system logs not output by default
output_pcap_logs: False # pcap logs not output by default
output_to_terminal: False # do not output to terminal by default
output_dir: null # none by default - none will print to repository root
# Logging # Logging
logging: logging:

View File

@@ -3,10 +3,12 @@ from datetime import datetime
from enum import IntEnum from enum import IntEnum
from pathlib import Path from pathlib import Path
from primaite import _PRIMAITE_ROOT from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG, PRIMAITE_PATHS
__all__ = ["SIM_OUTPUT"] __all__ = ["SIM_OUTPUT"]
from primaite.utils.cli.primaite_config_utils import is_dev_mode
class LogLevel(IntEnum): class LogLevel(IntEnum):
"""Enum containing all the available log levels for PrimAITE simulation output.""" """Enum containing all the available log levels for PrimAITE simulation output."""
@@ -25,16 +27,34 @@ class LogLevel(IntEnum):
class _SimOutput: class _SimOutput:
def __init__(self): def __init__(self):
self._path: Path = ( date_str = datetime.now().strftime("%Y-%m-%d")
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") time_str = datetime.now().strftime("%H-%M-%S")
)
self.save_pcap_logs: bool = False path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
self.save_sys_logs: bool = False
self.write_sys_log_to_terminal: bool = False self._path = path
self.sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING self._save_pcap_logs: bool = False
self._save_sys_logs: bool = False
self._write_sys_log_to_terminal: bool = False
self._sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
@property @property
def path(self) -> Path: def path(self) -> Path:
if is_dev_mode():
date_str = datetime.now().strftime("%Y-%m-%d")
time_str = datetime.now().strftime("%H-%M-%S")
# if dev mode is enabled, if output dir is not set, print to primaite repo root
path: Path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str / "simulation_output"
# otherwise print to output dir
if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
path: Path = (
Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"])
/ "sessions"
/ date_str
/ time_str
/ "simulation_output"
)
self._path = path
return self._path return self._path
@path.setter @path.setter
@@ -42,5 +62,45 @@ class _SimOutput:
self._path = new_path self._path = new_path
self._path.mkdir(exist_ok=True, parents=True) self._path.mkdir(exist_ok=True, parents=True)
@property
def save_pcap_logs(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_pcap_logs")
return self._save_pcap_logs
@save_pcap_logs.setter
def save_pcap_logs(self, save_pcap_logs: bool) -> None:
self._save_pcap_logs = save_pcap_logs
@property
def save_sys_logs(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_sys_logs")
return self._save_sys_logs
@save_sys_logs.setter
def save_sys_logs(self, save_sys_logs: bool) -> None:
self._save_sys_logs = save_sys_logs
@property
def write_sys_log_to_terminal(self) -> bool:
if is_dev_mode():
return PRIMAITE_CONFIG.get("developer_mode").get("output_to_terminal")
return self._write_sys_log_to_terminal
@write_sys_log_to_terminal.setter
def write_sys_log_to_terminal(self, write_sys_log_to_terminal: bool) -> None:
self._write_sys_log_to_terminal = write_sys_log_to_terminal
@property
def sys_log_level(self) -> LogLevel:
if is_dev_mode():
return LogLevel[PRIMAITE_CONFIG.get("developer_mode").get("sys_log_level")]
return self._sys_log_level
@sys_log_level.setter
def sys_log_level(self, sys_log_level: LogLevel) -> None:
self._sys_log_level = sys_log_level
SIM_OUTPUT = _SimOutput() SIM_OUTPUT = _SimOutput()

View File

@@ -261,7 +261,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.10.11" "version": "3.10.12"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@@ -256,9 +256,11 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "22", "id": "22",
"metadata": {}, "metadata": {
"tags": []
},
"source": [ "source": [
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`." "Calling `switch.arp.show()` displays the Switch ARP Cache."
] ]
}, },
{ {
@@ -270,13 +272,33 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"switch_1\").sys_log.show()" "network.get_node_by_hostname(\"switch_1\").arp.show()"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "24", "id": "24",
"metadata": {}, "metadata": {},
"source": [
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"network.get_node_by_hostname(\"switch_1\").sys_log.show()"
]
},
{
"cell_type": "markdown",
"id": "26",
"metadata": {},
"source": [ "source": [
"### Computer/Server Nodes\n", "### Computer/Server Nodes\n",
"\n", "\n",
@@ -285,7 +307,7 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "25", "id": "27",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@@ -293,26 +315,6 @@
"Calling `computer.show()` displays the NICs on the Computer/Server." "Calling `computer.show()` displays the NICs on the Computer/Server."
] ]
}, },
{
"cell_type": "code",
"execution_count": null,
"id": "26",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"network.get_node_by_hostname(\"security_suite\").show()"
]
},
{
"cell_type": "markdown",
"id": "27",
"metadata": {},
"source": [
"Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
]
},
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
@@ -322,7 +324,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"security_suite\").arp.show()" "network.get_node_by_hostname(\"security_suite\").show()"
] ]
}, },
{ {
@@ -330,7 +332,7 @@
"id": "29", "id": "29",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Calling `computer.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`." "Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
] ]
}, },
{ {
@@ -342,7 +344,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"security_suite\").sys_log.show()" "network.get_node_by_hostname(\"security_suite\").arp.show()"
] ]
}, },
{ {
@@ -350,9 +352,7 @@
"id": "31", "id": "31",
"metadata": {}, "metadata": {},
"source": [ "source": [
"## Basic Network Comms Check\n", "Calling `computer.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
"\n",
"We can perform a good old ping to check that Nodes are able to communicate with each other."
] ]
}, },
{ {
@@ -364,7 +364,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.show(nodes=False, links=False)" "network.get_node_by_hostname(\"security_suite\").sys_log.show()"
] ]
}, },
{ {
@@ -372,7 +372,9 @@
"id": "33", "id": "33",
"metadata": {}, "metadata": {},
"source": [ "source": [
"We'll first ping client_1's default gateway." "## Basic Network Comms Check\n",
"\n",
"We can perform a good old ping to check that Nodes are able to communicate with each other."
] ]
}, },
{ {
@@ -384,27 +386,27 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.10.1\")" "network.show(nodes=False, links=False)"
]
},
{
"cell_type": "markdown",
"id": "35",
"metadata": {},
"source": [
"We'll first ping client_1's default gateway."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "35", "id": "36",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"client_1\").sys_log.show(15)" "network.get_node_by_hostname(\"client_1\").ping(\"192.168.10.1\")"
]
},
{
"cell_type": "markdown",
"id": "36",
"metadata": {},
"source": [
"Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
] ]
}, },
{ {
@@ -416,7 +418,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.1\")" "network.get_node_by_hostname(\"client_1\").sys_log.show(15)"
] ]
}, },
{ {
@@ -424,7 +426,7 @@
"id": "38", "id": "38",
"metadata": {}, "metadata": {},
"source": [ "source": [
"And finally, we'll ping the web server." "Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
] ]
}, },
{ {
@@ -436,7 +438,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")" "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.1\")"
] ]
}, },
{ {
@@ -444,7 +446,7 @@
"id": "40", "id": "40",
"metadata": {}, "metadata": {},
"source": [ "source": [
"To confirm that the ping was received and processed by the web_server, we can view the sys log" "And finally, we'll ping the web server."
] ]
}, },
{ {
@@ -456,45 +458,45 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"web_server\").sys_log.show()" "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "42", "id": "42",
"metadata": {}, "metadata": {},
"source": [
"To confirm that the ping was received and processed by the web_server, we can view the sys log"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "43",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"network.get_node_by_hostname(\"web_server\").sys_log.show()"
]
},
{
"cell_type": "markdown",
"id": "44",
"metadata": {},
"source": [ "source": [
"## Advanced Network Usage\n", "## Advanced Network Usage\n",
"\n", "\n",
"We can now use the Network to perform some more advanced things." "We can now use the Network to perform some more advanced things."
] ]
}, },
{
"cell_type": "markdown",
"id": "43",
"metadata": {},
"source": [
"Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "45", "id": "45",
"metadata": {}, "metadata": {},
"source": [ "source": [
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:" "Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
] ]
}, },
{ {
@@ -506,13 +508,33 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"client_2\").sys_log.show()" "network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
] ]
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "47", "id": "47",
"metadata": {}, "metadata": {},
"source": [
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "48",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"network.get_node_by_hostname(\"client_2\").sys_log.show()"
]
},
{
"cell_type": "markdown",
"id": "49",
"metadata": {},
"source": [ "source": [
"Now we'll add an ACL to block ICMP from 192.168.10.22" "Now we'll add an ACL to block ICMP from 192.168.10.22"
] ]
@@ -520,7 +542,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "48", "id": "50",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@@ -540,7 +562,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "49", "id": "51",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },
@@ -549,32 +571,12 @@
"network.get_node_by_hostname(\"router_1\").acl.show()" "network.get_node_by_hostname(\"router_1\").acl.show()"
] ]
}, },
{
"cell_type": "markdown",
"id": "50",
"metadata": {},
"source": [
"Now we attempt (and fail) to ping the web server"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "51",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"id": "52", "id": "52",
"metadata": {}, "metadata": {},
"source": [ "source": [
"We can check that the ping was actually sent by client_2 by viewing the sys log" "Now we attempt (and fail) to ping the web server"
] ]
}, },
{ {
@@ -586,7 +588,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"client_2\").sys_log.show()" "network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
] ]
}, },
{ {
@@ -594,7 +596,7 @@
"id": "54", "id": "54",
"metadata": {}, "metadata": {},
"source": [ "source": [
"We can check the router sys log to see why the traffic was blocked" "We can check that the ping was actually sent by client_2 by viewing the sys log"
] ]
}, },
{ {
@@ -606,7 +608,7 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"router_1\").sys_log.show()" "network.get_node_by_hostname(\"client_2\").sys_log.show()"
] ]
}, },
{ {
@@ -614,7 +616,7 @@
"id": "56", "id": "56",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Now a final check to ensure that client_1 can still ping the web_server." "We can check the router sys log to see why the traffic was blocked"
] ]
}, },
{ {
@@ -625,6 +627,26 @@
"tags": [] "tags": []
}, },
"outputs": [], "outputs": [],
"source": [
"network.get_node_by_hostname(\"router_1\").sys_log.show()"
]
},
{
"cell_type": "markdown",
"id": "58",
"metadata": {},
"source": [
"Now a final check to ensure that client_1 can still ping the web_server."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "59",
"metadata": {
"tags": []
},
"outputs": [],
"source": [ "source": [
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")" "network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")"
] ]
@@ -632,7 +654,7 @@
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": null, "execution_count": null,
"id": "58", "id": "60",
"metadata": { "metadata": {
"tags": [] "tags": []
}, },

View File

@@ -2,9 +2,7 @@ from __future__ import annotations
import hashlib import hashlib
import json import json
import os.path
import warnings import warnings
from pathlib import Path
from typing import Dict, Optional from typing import Dict, Optional
from primaite import getLogger from primaite import getLogger
@@ -21,8 +19,6 @@ class File(FileSystemItemABC):
:ivar Folder folder: The folder in which the file resides. :ivar Folder folder: The folder in which the file resides.
:ivar FileType file_type: The type of the file. :ivar FileType file_type: The type of the file.
:ivar Optional[int] sim_size: The simulated file size. :ivar Optional[int] sim_size: The simulated file size.
:ivar bool real: Indicates if the file is actually a real file in the Node sim fs output.
:ivar Optional[Path] sim_path: The path if the file is real.
""" """
folder_id: str folder_id: str
@@ -33,12 +29,6 @@ class File(FileSystemItemABC):
"The type of File." "The type of File."
sim_size: Optional[int] = None sim_size: Optional[int] = None
"The simulated file size." "The simulated file size."
real: bool = False
"Indicates whether the File is actually a real file in the Node sim fs output."
sim_path: Optional[Path] = None
"The Path if real is True."
sim_root: Optional[Path] = None
"Root path of the simulation."
num_access: int = 0 num_access: int = 0
"Number of times the file was accessed in the current step." "Number of times the file was accessed in the current step."
@@ -67,13 +57,6 @@ class File(FileSystemItemABC):
if not kwargs.get("sim_size"): if not kwargs.get("sim_size"):
kwargs["sim_size"] = kwargs["file_type"].default_size kwargs["sim_size"] = kwargs["file_type"].default_size
super().__init__(**kwargs) super().__init__(**kwargs)
if self.real:
self.sim_path = self.sim_root / self.path
if not self.sim_path.exists():
self.sim_path.parent.mkdir(exist_ok=True, parents=True)
with open(self.sim_path, mode="a"):
pass
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
@property @property
@@ -92,8 +75,6 @@ class File(FileSystemItemABC):
:return: The size of the file in bytes. :return: The size of the file in bytes.
""" """
if self.real:
return os.path.getsize(self.sim_path)
return self.sim_size return self.sim_size
def apply_timestep(self, timestep: int) -> None: def apply_timestep(self, timestep: int) -> None:
@@ -127,7 +108,7 @@ class File(FileSystemItemABC):
self.num_access += 1 # file was accessed self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name path = self.folder.name + "/" + self.name
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}") self.sys_log.info(f"Scanning file {path}")
self.visible_health_status = self.health_status self.visible_health_status = self.health_status
return True return True
@@ -155,17 +136,8 @@ class File(FileSystemItemABC):
return False return False
current_hash = None current_hash = None
# if file is real, read the file contents # otherwise get describe_state dict and hash that
if self.real: current_hash = hashlib.blake2b(json.dumps(self.describe_state(), sort_keys=True).encode()).hexdigest()
with open(self.sim_path, "rb") as f:
file_hash = hashlib.blake2b()
while chunk := f.read(8192):
file_hash.update(chunk)
current_hash = file_hash.hexdigest()
else:
# otherwise get describe_state dict and hash that
current_hash = hashlib.blake2b(json.dumps(self.describe_state(), sort_keys=True).encode()).hexdigest()
# if the previous hash is None, set the current hash to previous # if the previous hash is None, set the current hash to previous
if self.previous_hash is None: if self.previous_hash is None:
@@ -188,7 +160,7 @@ class File(FileSystemItemABC):
self.num_access += 1 # file was accessed self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name path = self.folder.name + "/" + self.name
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}") self.sys_log.info(f"Repaired file {path}")
return True return True
def corrupt(self) -> bool: def corrupt(self) -> bool:
@@ -203,7 +175,7 @@ class File(FileSystemItemABC):
self.num_access += 1 # file was accessed self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name path = self.folder.name + "/" + self.name
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}") self.sys_log.info(f"Corrupted file {path}")
return True return True
def restore(self) -> bool: def restore(self) -> bool:
@@ -217,7 +189,7 @@ class File(FileSystemItemABC):
self.num_access += 1 # file was accessed self.num_access += 1 # file was accessed
path = self.folder.name + "/" + self.name path = self.folder.name + "/" + self.name
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}") self.sys_log.info(f"Restored file {path}")
return True return True
def delete(self) -> bool: def delete(self) -> bool:

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import shutil
from pathlib import Path from pathlib import Path
from typing import Dict, Optional from typing import Dict, Optional
@@ -230,7 +229,6 @@ class FileSystem(SimComponent):
size: Optional[int] = None, size: Optional[int] = None,
file_type: Optional[FileType] = None, file_type: Optional[FileType] = None,
folder_name: Optional[str] = None, folder_name: Optional[str] = None,
real: bool = False,
) -> File: ) -> File:
""" """
Creates a File and adds it to the list of files. Creates a File and adds it to the list of files.
@@ -239,7 +237,6 @@ class FileSystem(SimComponent):
:param size: The size the file takes on disk in bytes. :param size: The size the file takes on disk in bytes.
:param file_type: The type of the file. :param file_type: The type of the file.
:param folder_name: The folder to add the file to. :param folder_name: The folder to add the file to.
:param real: "Indicates whether the File is actually a real file in the Node sim fs output."
""" """
if folder_name: if folder_name:
# check if file with name already exists # check if file with name already exists
@@ -258,8 +255,6 @@ class FileSystem(SimComponent):
file_type=file_type, file_type=file_type,
folder_id=folder.uuid, folder_id=folder.uuid,
folder_name=folder.name, folder_name=folder.name,
real=real,
sim_path=self.sim_root if real else None,
sim_root=self.sim_root, sim_root=self.sim_root,
sys_log=self.sys_log, sys_log=self.sys_log,
) )
@@ -368,11 +363,6 @@ class FileSystem(SimComponent):
# add file to dst # add file to dst
dst_folder.add_file(file) dst_folder.add_file(file)
self.num_file_creations += 1 self.num_file_creations += 1
if file.real:
old_sim_path = file.sim_path
file.sim_path = file.sim_root / file.path
file.sim_path.parent.mkdir(exist_ok=True)
shutil.move(old_sim_path, file.sim_path)
def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str): def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
""" """
@@ -401,9 +391,6 @@ class FileSystem(SimComponent):
dst_folder.add_file(file_copy, force=True) dst_folder.add_file(file_copy, force=True)
if file.real:
file_copy.sim_path.parent.mkdir(exist_ok=True)
shutil.copy2(file.sim_path, file_copy.sim_path)
else: else:
self.sys_log.error(f"Unable to copy file. {src_file_name} does not exist.") self.sys_log.error(f"Unable to copy file. {src_file_name} does not exist.")

View File

@@ -192,7 +192,6 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
# Cannot send Frame as the network interface is not enabled # Cannot send Frame as the network interface is not enabled
return False return False
@abstractmethod
def receive_frame(self, frame: Frame) -> bool: def receive_frame(self, frame: Frame) -> bool:
""" """
Receives a network frame on the network interface. Receives a network frame on the network interface.
@@ -200,7 +199,13 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
:param frame: The network frame being received. :param frame: The network frame being received.
:return: A boolean indicating whether the frame was successfully received. :return: A boolean indicating whether the frame was successfully received.
""" """
pass if self.enabled:
frame.set_sent_timestamp()
self.pcap.capture_inbound(frame)
self._connected_node.receive_frame(frame, self)
return True
# Cannot receive Frame as the network interface is not enabled
return False
class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC): class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC):

View File

@@ -1378,7 +1378,7 @@ class Node(SimComponent):
application_instance.configure(server_ip_address=IPv4Address(ip_address)) application_instance.configure(server_ip_address=IPv4Address(ip_address))
else: else:
pass pass
application_instance.install()
if application_instance.name in self.software_manager.software: if application_instance.name in self.software_manager.software:
return True return True
else: else:

View File

@@ -1,4 +1,7 @@
from typing import ClassVar, Dict
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
class Computer(HostNode): class Computer(HostNode):
@@ -29,4 +32,6 @@ class Computer(HostNode):
* Web Browser * Web Browser
""" """
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient}
pass pass

View File

@@ -10,7 +10,6 @@ from primaite.simulator.network.transmission.data_link_layer import Frame
from primaite.simulator.system.applications.web_browser import WebBrowser from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.arp.arp import ARP, ARPPacket from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.icmp.icmp import ICMP from primaite.simulator.system.services.icmp.icmp import ICMP
from primaite.simulator.system.services.ntp.ntp_client import NTPClient from primaite.simulator.system.services.ntp.ntp_client import NTPClient
from primaite.utils.validators import IPV4Address from primaite.utils.validators import IPV4Address
@@ -301,7 +300,6 @@ class HostNode(Node):
"HostARP": HostARP, "HostARP": HostARP,
"ICMP": ICMP, "ICMP": ICMP,
"DNSClient": DNSClient, "DNSClient": DNSClient,
"FTPClient": FTPClient,
"NTPClient": NTPClient, "NTPClient": NTPClient,
"WebBrowser": WebBrowser, "WebBrowser": WebBrowser,
} }

View File

@@ -1,6 +1,6 @@
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Any, Dict, Set from typing import Any, Dict, Optional, Set
from primaite.interface.request import RequestResponse from primaite.interface.request import RequestResponse
from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.core import RequestManager, RequestType
@@ -33,6 +33,10 @@ class Application(IOSoftware):
"The number of times the application has been executed. Default is 0." "The number of times the application has been executed. Default is 0."
groups: Set[str] = set() groups: Set[str] = set()
"The set of groups to which the application belongs." "The set of groups to which the application belongs."
install_duration: int = 2
"How long it takes to install the application."
install_countdown: Optional[int] = None
"The countdown to the end of the installation process. None if not currently installing"
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
@@ -76,6 +80,12 @@ class Application(IOSoftware):
:param timestep: The current timestep of the simulation. :param timestep: The current timestep of the simulation.
""" """
super().apply_timestep(timestep=timestep) super().apply_timestep(timestep=timestep)
if self.operating_state is ApplicationOperatingState.INSTALLING:
self.install_countdown -= 1
if self.install_countdown <= 0:
self.operating_state = ApplicationOperatingState.RUNNING
self.health_state_actual = SoftwareHealthState.GOOD
self.install_countdown = None
def pre_timestep(self, timestep: int) -> None: def pre_timestep(self, timestep: int) -> None:
"""Apply pre-timestep logic.""" """Apply pre-timestep logic."""
@@ -129,6 +139,7 @@ class Application(IOSoftware):
super().install() super().install()
if self.operating_state == ApplicationOperatingState.CLOSED: if self.operating_state == ApplicationOperatingState.CLOSED:
self.operating_state = ApplicationOperatingState.INSTALLING self.operating_state = ApplicationOperatingState.INSTALLING
self.install_countdown = self.install_duration
def receive(self, payload: Any, session_id: str, **kwargs) -> bool: def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
""" """

View File

@@ -177,4 +177,5 @@ class DoSBot(DatabaseClient):
:param timestep: The timestep value to update the bot's state. :param timestep: The timestep value to update the bot's state.
""" """
super().apply_timestep(timestep=timestep)
self._application_loop() self._application_loop()

View File

@@ -118,14 +118,6 @@ class RansomwareScript(Application):
self.sys_log.info(f"{self.name}: Activated!") self.sys_log.info(f"{self.name}: Activated!")
self.attack_stage = RansomwareAttackStage.ACTIVATE self.attack_stage = RansomwareAttackStage.ACTIVATE
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, triggering the application loop.
:param timestep: The timestep value to update the bot's state.
"""
pass
def run(self) -> bool: def run(self) -> bool:
"""Calls the parent classes execute method before starting the application loop.""" """Calls the parent classes execute method before starting the application loop."""
super().run() super().run()

View File

@@ -126,7 +126,6 @@ class FTPClient(FTPServiceABC):
dest_file_name: str, dest_file_name: str,
dest_port: Optional[Port] = Port.FTP, dest_port: Optional[Port] = Port.FTP,
session_id: Optional[str] = None, session_id: Optional[str] = None,
real_file_path: Optional[str] = None,
) -> bool: ) -> bool:
""" """
Send a file to a target IP address. Send a file to a target IP address.

View File

@@ -1,4 +1,3 @@
import shutil
from abc import ABC from abc import ABC
from ipaddress import IPv4Address from ipaddress import IPv4Address
from typing import Dict, Optional from typing import Dict, Optional
@@ -55,19 +54,17 @@ class FTPServiceABC(Service, ABC):
file_name = payload.ftp_command_args["dest_file_name"] file_name = payload.ftp_command_args["dest_file_name"]
folder_name = payload.ftp_command_args["dest_folder_name"] folder_name = payload.ftp_command_args["dest_folder_name"]
file_size = payload.ftp_command_args["file_size"] file_size = payload.ftp_command_args["file_size"]
real_file_path = payload.ftp_command_args.get("real_file_path")
health_status = payload.ftp_command_args["health_status"] health_status = payload.ftp_command_args["health_status"]
is_real = real_file_path is not None
file = self.file_system.create_file( file = self.file_system.create_file(
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real file_name=file_name,
folder_name=folder_name,
size=file_size,
) )
file.health_status = health_status file.health_status = health_status
self.sys_log.info( self.sys_log.info(
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/" f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
f"{payload.ftp_command_args['dest_file_name']}" f"{payload.ftp_command_args['dest_file_name']}"
) )
if is_real:
shutil.copy(real_file_path, file.sim_path)
# file should exist # file should exist
return self.file_system.get_file(file_name=file_name, folder_name=folder_name) is not None return self.file_system.get_file(file_name=file_name, folder_name=folder_name) is not None
except Exception as e: except Exception as e:
@@ -115,7 +112,6 @@ class FTPServiceABC(Service, ABC):
"dest_folder_name": dest_folder_name, "dest_folder_name": dest_folder_name,
"dest_file_name": dest_file_name, "dest_file_name": dest_file_name,
"file_size": file.sim_size, "file_size": file.sim_size,
"real_file_path": file.sim_path if file.real else None,
"health_status": file.health_status, "health_status": file.health_status,
}, },
packet_payload_size=file.sim_size, packet_payload_size=file.sim_size,

View File

View File

@@ -0,0 +1,171 @@
import click
import typer
from rich import print
from rich.table import Table
from typing_extensions import Annotated
from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG
from primaite.simulator import LogLevel
from primaite.utils.cli.primaite_config_utils import is_dev_mode, update_primaite_application_config
dev = typer.Typer()
PRODUCTION_MODE_MESSAGE = (
"\n[green]:rocket::rocket::rocket: "
" PrimAITE is running in Production mode "
" :rocket::rocket::rocket: [/green]\n"
)
DEVELOPER_MODE_MESSAGE = (
"\n[yellow] :construction::construction::construction: "
" PrimAITE is running in Development mode "
" :construction::construction::construction: [/yellow]\n"
)
def dev_mode():
"""
CLI commands relevant to the dev-mode for PrimAITE.
The dev-mode contains tools that help with the ease of developing or debugging PrimAITE.
By default, PrimAITE will be in production mode.
To enable development mode, use `primaite dev-mode enable`
"""
@dev.command()
def show():
"""Show if PrimAITE is in development mode or production mode."""
# print if dev mode is enabled
print(DEVELOPER_MODE_MESSAGE if is_dev_mode() else PRODUCTION_MODE_MESSAGE)
table = Table(title="Current Dev-Mode Settings")
table.add_column("Setting", style="cyan")
table.add_column("Value", style="default")
for setting, value in PRIMAITE_CONFIG["developer_mode"].items():
table.add_row(setting, str(value))
print(table)
print("\nTo see available options, use [cyan]`primaite dev-mode --help`[/cyan]\n")
@dev.command()
def enable():
"""Enable the development mode for PrimAITE."""
# enable dev mode
PRIMAITE_CONFIG["developer_mode"]["enabled"] = True
update_primaite_application_config()
print(DEVELOPER_MODE_MESSAGE)
@dev.command()
def disable():
"""Disable the development mode for PrimAITE."""
# disable dev mode
PRIMAITE_CONFIG["developer_mode"]["enabled"] = False
update_primaite_application_config()
print(PRODUCTION_MODE_MESSAGE)
def config_callback(
ctx: typer.Context,
sys_log_level: Annotated[
LogLevel,
typer.Option(
"--sys-log-level",
"-level",
click_type=click.Choice(LogLevel._member_names_, case_sensitive=False),
help="The level of system logs to output.",
show_default=False,
),
] = None,
output_sys_logs: Annotated[
bool,
typer.Option(
"--output-sys-logs/--no-sys-logs", "-sys/-nsys", help="Output system logs to file.", show_default=False
),
] = None,
output_pcap_logs: Annotated[
bool,
typer.Option(
"--output-pcap-logs/--no-pcap-logs",
"-pcap/-npcap",
help="Output network packet capture logs to file.",
show_default=False,
),
] = None,
output_to_terminal: Annotated[
bool,
typer.Option(
"--output-to-terminal/--no-terminal", "-t/-nt", help="Output system logs to terminal.", show_default=False
),
] = None,
):
"""Configure the development tools and environment."""
if ctx.params.get("sys_log_level") is not None:
PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] = ctx.params.get("sys_log_level")
print(f"PrimAITE dev-mode config updated sys_log_level={ctx.params.get('sys_log_level')}")
if output_sys_logs is not None:
PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] = output_sys_logs
print(f"PrimAITE dev-mode config updated {output_sys_logs=}")
if output_pcap_logs is not None:
PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] = output_pcap_logs
print(f"PrimAITE dev-mode config updated {output_pcap_logs=}")
if output_to_terminal is not None:
PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] = output_to_terminal
print(f"PrimAITE dev-mode config updated {output_to_terminal=}")
# update application config
update_primaite_application_config()
config_typer = typer.Typer(
callback=config_callback,
name="config",
no_args_is_help=True,
invoke_without_command=True,
)
dev.add_typer(config_typer)
@config_typer.command()
def path(
directory: Annotated[
str,
typer.Argument(
help="Directory where the system logs and PCAP logs will be output. By default, this will be where the"
"root of the PrimAITE repository is located.",
show_default=False,
),
] = None,
default: Annotated[
bool,
typer.Option(
"--default",
"-root",
help="Set PrimAITE to output system logs and pcap logs to the PrimAITE repository root.",
),
] = None,
):
"""Set the output directory for the PrimAITE system and PCAP logs."""
if default:
PRIMAITE_CONFIG["developer_mode"]["output_dir"] = None
# update application config
update_primaite_application_config()
print(
f"PrimAITE dev-mode output_dir [cyan]"
f"{str(_PRIMAITE_ROOT.parent.parent / 'simulation_output')}"
f"[/cyan]"
)
return
if directory:
PRIMAITE_CONFIG["developer_mode"]["output_dir"] = directory
# update application config
update_primaite_application_config()
print(f"PrimAITE dev-mode output_dir [cyan]{directory}[/cyan]")

View File

@@ -0,0 +1,22 @@
from typing import Dict, Optional
import yaml
from primaite import PRIMAITE_CONFIG, PRIMAITE_PATHS
def is_dev_mode() -> bool:
"""Returns True if PrimAITE is currently running in developer mode."""
return PRIMAITE_CONFIG.get("developer_mode", {}).get("enabled", False)
def update_primaite_application_config(config: Optional[Dict] = None) -> None:
"""
Update the PrimAITE application config file.
:params: config: Leave empty so that PRIMAITE_CONFIG is used - otherwise provide the Dict
"""
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
if not config:
config = PRIMAITE_CONFIG
yaml.dump(config, file)

View File

@@ -1,11 +0,0 @@
import yaml
from primaite import PRIMAITE_PATHS
def is_dev_mode() -> bool:
"""Returns True if PrimAITE is currently running in developer mode."""
if PRIMAITE_PATHS.app_config_file_path.exists():
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
primaite_config = yaml.safe_load(file)
return primaite_config["developer_mode"]

View File

@@ -1,11 +1,8 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from datetime import datetime from typing import Any, Dict, Tuple
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
import pytest import pytest
import yaml import yaml
from _pytest.monkeypatch import MonkeyPatch
from primaite import getLogger, PRIMAITE_PATHS from primaite import getLogger, PRIMAITE_PATHS
from primaite.game.agent.actions import ActionManager from primaite.game.agent.actions import ActionManager
@@ -13,7 +10,6 @@ from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction from primaite.game.agent.rewards import RewardFunction
from primaite.game.game import PrimaiteGame from primaite.game.game import PrimaiteGame
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.container import Network from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.computer import Computer
@@ -32,7 +28,6 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.services.web_server.web_server import WebServer
from tests import TEST_ASSETS_ROOT from tests import TEST_ASSETS_ROOT
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
ACTION_SPACE_NODE_VALUES = 1 ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1 ACTION_SPACE_NODE_ACTION_VALUES = 1
@@ -40,21 +35,6 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__) _LOGGER = getLogger(__name__)
@pytest.fixture(scope="function", autouse=True)
def set_syslog_output_to_true():
"""Will be run before each test."""
monkeypatch = MonkeyPatch()
monkeypatch.setattr(
SIM_OUTPUT,
"path",
Path(TEST_ASSETS_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")),
)
monkeypatch.setattr(SIM_OUTPUT, "save_pcap_logs", False)
monkeypatch.setattr(SIM_OUTPUT, "save_sys_logs", False)
yield
class TestService(Service): class TestService(Service):
"""Test Service class""" """Test Service class"""
@@ -86,7 +66,10 @@ class TestApplication(Application):
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def uc2_network() -> Network: def uc2_network() -> Network:
return arcd_uc2_network() with open(PRIMAITE_PATHS.user_config_path / "example_config" / "data_manipulation.yaml") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
return game.simulation.network
@pytest.fixture(scope="function") @pytest.fixture(scope="function")

View File

@@ -38,3 +38,5 @@ def test_rllib_single_agent_compatibility():
save_file = Path(tempfile.gettempdir()) / "ray/" save_file = Path(tempfile.gettempdir()) / "ray/"
algo.save(save_file) algo.save(save_file)
assert save_file.exists() assert save_file.exists()
save_file.unlink() # clean up

View File

@@ -25,3 +25,4 @@ def test_sb3_compatibility():
model.save(save_path) model.save(save_path)
assert (save_path).exists() assert (save_path).exists()
save_path.unlink() # clean up

View File

@@ -26,6 +26,9 @@ def test_data_manipulation(uc2_network):
# First check that the DB client on the web_server can successfully query the users table on the database # First check that the DB client on the web_server can successfully query the users table on the database
assert db_connection.query("SELECT") assert db_connection.query("SELECT")
db_manipulation_bot.data_manipulation_p_of_success = 1.0
db_manipulation_bot.port_scan_p_of_success = 1.0
# Now we run the DataManipulationBot # Now we run the DataManipulationBot
db_manipulation_bot.attack() db_manipulation_bot.attack()
@@ -59,6 +62,11 @@ def test_application_install_uninstall_on_uc2():
_, _, _, _, info = env.step(78) _, _, _, _, info = env.step(78)
assert "DoSBot" in domcon.software_manager.software assert "DoSBot" in domcon.software_manager.software
# installing takes 3 steps so let's wait for 3 steps
env.step(0)
env.step(0)
env.step(0)
# Test we can now execute the DoSBot app # Test we can now execute the DoSBot app
_, _, _, _, info = env.step(81) _, _, _, _, info = env.step(81)
assert info["agent_actions"]["defender"].response.status == "success" assert info["agent_actions"]["defender"].response.status == "success"

View File

@@ -0,0 +1,11 @@
from typing import List
from typer.testing import CliRunner, Result
from primaite.cli import app
def cli(args: List[str]) -> Result:
"""Pass in a list of arguments and it will return the result."""
runner = CliRunner()
return runner.invoke(app, args)

View File

@@ -0,0 +1,171 @@
import os
import shutil
import tempfile
from pathlib import Path
import pkg_resources
import pytest
import yaml
from primaite import PRIMAITE_CONFIG
from primaite.utils.cli.primaite_config_utils import update_primaite_application_config
from tests.integration_tests.cli import cli
@pytest.fixture(autouse=True)
def test_setup():
"""
Setup this test by using the default primaite app config in package
"""
global PRIMAITE_CONFIG
current_config = PRIMAITE_CONFIG.copy() # store the config before test
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
with open(pkg_config_path, "r") as file:
# load from config
config_dict = yaml.safe_load(file)
PRIMAITE_CONFIG["developer_mode"] = config_dict["developer_mode"]
yield
PRIMAITE_CONFIG["developer_mode"] = current_config["developer_mode"] # restore config to prevent being yelled at
update_primaite_application_config(config=PRIMAITE_CONFIG)
def test_dev_mode_enable_disable():
"""Test dev mode enable and disable."""
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["enabled"] is False # not enabled by default
result = cli(["dev-mode", "show"])
assert "Production" in result.output # should print that it is in Production mode by default
result = cli(["dev-mode", "enable"])
assert "Development" in result.output # should print that it is in Development mode
assert PRIMAITE_CONFIG["developer_mode"]["enabled"] # config should reflect that dev mode is enabled
result = cli(["dev-mode", "show"])
assert "Development" in result.output # should print that it is in Development mode
result = cli(["dev-mode", "disable"])
assert "Production" in result.output # should print that it is in Production mode
assert PRIMAITE_CONFIG["developer_mode"]["enabled"] is False # config should reflect that dev mode is disabled
result = cli(["dev-mode", "show"])
assert "Production" in result.output # should print that it is in Production mode
def test_dev_mode_config_sys_log_level():
"""Check that the system log level can be changed via CLI."""
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "DEBUG" # DEBUG by default
result = cli(["dev-mode", "config", "-level", "WARNING"])
assert "sys_log_level=WARNING" in result.output # should print correct value
# config should reflect that log level is WARNING
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "WARNING"
result = cli(["dev-mode", "config", "--sys-log-level", "INFO"])
assert "sys_log_level=INFO" in result.output # should print correct value
# config should reflect that log level is WARNING
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "INFO"
def test_dev_mode_config_sys_logs_enable_disable():
"""Test that the system logs output can be enabled or disabled."""
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False # False by default
result = cli(["dev-mode", "config", "--output-sys-logs"])
assert "output_sys_logs=True" in result.output # should print correct value
# config should reflect that output_sys_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"]
result = cli(["dev-mode", "config", "--no-sys-logs"])
assert "output_sys_logs=False" in result.output # should print correct value
# config should reflect that output_sys_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False
result = cli(["dev-mode", "config", "-sys"])
assert "output_sys_logs=True" in result.output # should print correct value
# config should reflect that output_sys_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"]
result = cli(["dev-mode", "config", "-nsys"])
assert "output_sys_logs=False" in result.output # should print correct value
# config should reflect that output_sys_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False
def test_dev_mode_config_pcap_logs_enable_disable():
"""Test that the pcap logs output can be enabled or disabled."""
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] is False # False by default
result = cli(["dev-mode", "config", "--output-pcap-logs"])
assert "output_pcap_logs=True" in result.output # should print correct value
# config should reflect that output_pcap_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"]
result = cli(["dev-mode", "config", "--no-pcap-logs"])
assert "output_pcap_logs=False" in result.output # should print correct value
# config should reflect that output_pcap_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] is False
result = cli(["dev-mode", "config", "-pcap"])
assert "output_pcap_logs=True" in result.output # should print correct value
# config should reflect that output_pcap_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"]
result = cli(["dev-mode", "config", "-npcap"])
assert "output_pcap_logs=False" in result.output # should print correct value
# config should reflect that output_pcap_logs is True
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] is False
def test_dev_mode_config_output_to_terminal_enable_disable():
"""Test that the output to terminal can be enabled or disabled."""
# check defaults
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] is False # False by default
result = cli(["dev-mode", "config", "--output-to-terminal"])
assert "output_to_terminal=True" in result.output # should print correct value
# config should reflect that output_to_terminal is True
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"]
result = cli(["dev-mode", "config", "--no-terminal"])
assert "output_to_terminal=False" in result.output # should print correct value
# config should reflect that output_to_terminal is True
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] is False
result = cli(["dev-mode", "config", "-t"])
assert "output_to_terminal=True" in result.output # should print correct value
# config should reflect that output_to_terminal is True
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"]
result = cli(["dev-mode", "config", "-nt"])
assert "output_to_terminal=False" in result.output # should print correct value
# config should reflect that output_to_terminal is True
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] is False

View File

@@ -83,7 +83,7 @@ def test_sometech_webserver_cannot_access_ftp_on_sometech_storage_server():
some_tech_storage_srv.file_system.create_file(file_name="test.png") some_tech_storage_srv.file_system.create_file(file_name="test.png")
web_server: Server = network.get_node_by_hostname("some_tech_web_srv") web_server: Server = network.get_node_by_hostname("some_tech_web_srv")
web_server.software_manager.install(FTPClient)
web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"] web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"]
assert not web_ftp_client.request_file( assert not web_ftp_client.request_file(

View File

@@ -101,7 +101,7 @@ def test_database_client_native_connection_query(uc2_network):
"""Tests DB query across the network returns HTTP status 200 and date.""" """Tests DB query across the network returns HTTP status 200 and date."""
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
assert db_client.query(sql="SELECT") assert db_client.query(sql="SELECT")
assert db_client.query(sql="INSERT") assert db_client.query(sql="INSERT")
@@ -222,6 +222,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server") web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient") db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_client.connect()
assert len(db_client.client_connections) assert len(db_client.client_connections)
# Establish a new connection to the DatabaseService # Establish a new connection to the DatabaseService

View File

@@ -57,20 +57,6 @@ def test_simulated_file_check_hash(file_system):
assert file.health_status == FileSystemItemHealthStatus.CORRUPT assert file.health_status == FileSystemItemHealthStatus.CORRUPT
@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented")
def test_real_file_check_hash(file_system):
file: File = file_system.create_file(file_name="test_file.txt", real=True)
file.check_hash()
assert file.health_status == FileSystemItemHealthStatus.GOOD
# change file content
with open(file.sim_path, "a") as f:
f.write("get hacked scrub lol xD\n")
file.check_hash()
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
def test_file_corrupt_repair_restore(file_system): def test_file_corrupt_repair_restore(file_system):
"""Test the ability to corrupt and repair files.""" """Test the ability to corrupt and repair files."""
file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder")

View File

@@ -191,7 +191,7 @@ def test_copy_file(file_system):
file_system.create_folder(folder_name="src_folder") file_system.create_folder(folder_name="src_folder")
file_system.create_folder(folder_name="dst_folder") file_system.create_folder(folder_name="dst_folder")
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True) file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder")
assert file_system.num_file_creations == 1 assert file_system.num_file_creations == 1
original_uuid = file.uuid original_uuid = file.uuid

View File

@@ -132,21 +132,3 @@ def test_simulated_folder_check_hash(file_system):
file.sim_size = 0 file.sim_size = 0
folder.check_hash() folder.check_hash()
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented")
def test_real_folder_check_hash(file_system):
folder: Folder = file_system.create_folder(folder_name="test_folder")
file_system.create_file(file_name="test_file.txt", folder_name="test_folder", real=True)
folder.check_hash()
assert folder.health_status == FileSystemItemHealthStatus.GOOD
# change simulated file size
file = folder.get_file(file_name="test_file.txt")
# change file content
with open(file.sim_path, "a") as f:
f.write("get hacked scrub lol xD\n")
folder.check_hash()
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT

View File

@@ -2,10 +2,20 @@ from uuid import uuid4
import pytest import pytest
from primaite import PRIMAITE_CONFIG
from primaite.simulator import LogLevel, SIM_OUTPUT from primaite.simulator import LogLevel, SIM_OUTPUT
from primaite.simulator.system.core.sys_log import SysLog from primaite.simulator.system.core.sys_log import SysLog
@pytest.fixture(autouse=True)
def override_dev_mode_temporarily():
"""Temporarily turn off dev mode for this test."""
primaite_dev_mode = PRIMAITE_CONFIG["developer_mode"]["enabled"]
PRIMAITE_CONFIG["developer_mode"]["enabled"] = False
yield # run tests
PRIMAITE_CONFIG["developer_mode"]["enabled"] = primaite_dev_mode
@pytest.fixture(scope="function") @pytest.fixture(scope="function")
def syslog() -> SysLog: def syslog() -> SysLog:
return SysLog(hostname="test") return SysLog(hostname="test")