Merge branch 'dev' into feature/2457-Set_link_bandwidth_via_config
This commit is contained in:
@@ -14,13 +14,13 @@ parameters:
|
||||
- name: matrix
|
||||
type: object
|
||||
default:
|
||||
# - job_name: 'UbuntuPython38'
|
||||
# py: '3.8'
|
||||
# img: 'ubuntu-latest'
|
||||
# every_time: false
|
||||
# publish_coverage: false
|
||||
- job_name: 'UbuntuPython310'
|
||||
py: '3.10'
|
||||
- job_name: 'UbuntuPython38'
|
||||
py: '3.8'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'UbuntuPython311'
|
||||
py: '3.11'
|
||||
img: 'ubuntu-latest'
|
||||
every_time: true
|
||||
publish_coverage: true
|
||||
@@ -29,8 +29,8 @@ parameters:
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'WindowsPython310'
|
||||
py: '3.10'
|
||||
- job_name: 'WindowsPython311'
|
||||
py: '3.11'
|
||||
img: 'windows-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
@@ -39,8 +39,8 @@ parameters:
|
||||
img: 'macOS-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
- job_name: 'MacOSPython310'
|
||||
py: '3.10'
|
||||
- job_name: 'MacOSPython311'
|
||||
py: '3.11'
|
||||
img: 'macOS-latest'
|
||||
every_time: false
|
||||
publish_coverage: false
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -82,6 +82,7 @@ target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
PPO_UC2/
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
@@ -150,6 +151,7 @@ docs/source/primaite-dependencies.rst
|
||||
# outputs
|
||||
src/primaite/outputs/
|
||||
simulation_output/
|
||||
sessions/
|
||||
|
||||
# benchmark session outputs
|
||||
benchmark/output
|
||||
|
||||
@@ -14,14 +14,18 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Added ability to define scenarios that change depending on the episode number.
|
||||
- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config`
|
||||
- Database Connection ID's are now created/issued by DatabaseService and not DatabaseClient
|
||||
- added ability to set PrimAITE between development and production modes via PrimAITE CLI ``mode`` command
|
||||
- Updated DatabaseClient so that it can now have a single native DatabaseClientConnection along with a collection of DatabaseClientConnection's.
|
||||
- Implemented the uninstall functionality for DatabaseClient so that all connections are terminated at the DatabaseService.
|
||||
- Added the ability for a DatabaseService to terminate a connection.
|
||||
- Added active_connection to DatabaseClientConnection so that if the connection is terminated active_connection is set to False and the object can no longer be used.
|
||||
- Added additional show functions to enable connection inspection.
|
||||
- Updates to agent logging, to include the reward both per step and per episode.
|
||||
- Introduced Developer CLI tools to assist with developing/debugging PrimAITE
|
||||
- Can be enabled via `primaite dev-mode enable`
|
||||
- Activating dev-mode will change the location where the sessions will be output - by default will output where the PrimAITE repository is located
|
||||
- Refactored all air-space usage to that a new instance of AirSpace is created for each instance of Network. This 1:1 relationship between network and airspace will allow parallelization.
|
||||
- Added notebook to demonstrate use of SubprocVecEnv from SB3 to vectorise environments to speed up training.
|
||||
|
||||
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
@@ -116,6 +116,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
|
||||
:caption: Developer information:
|
||||
:hidden:
|
||||
|
||||
source/developer_tools
|
||||
source/state_system
|
||||
source/request_system
|
||||
PrimAITE API <source/_autosummary/primaite>
|
||||
|
||||
210
docs/source/developer_tools.rst
Normal file
210
docs/source/developer_tools.rst
Normal file
@@ -0,0 +1,210 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _Developer Tools:
|
||||
|
||||
Developer Tools
|
||||
***************
|
||||
|
||||
PrimAITE includes developer CLI tools that are intended to be used by developers.
|
||||
|
||||
dev-mode
|
||||
========
|
||||
|
||||
The dev-mode contains configuration which override any of the config files during runtime.
|
||||
|
||||
This is intended to make debugging easier by removing the need to find the relevant configuration file/settings.
|
||||
|
||||
Enabling dev-mode
|
||||
-----------------
|
||||
|
||||
The PrimAITE dev-mode can be enabled via the use of
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode enable
|
||||
|
||||
Disabling dev-mode
|
||||
------------------
|
||||
|
||||
The PrimAITE dev-mode can be disabled via the use of
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode disable
|
||||
|
||||
Show current mode
|
||||
-----------------
|
||||
|
||||
To show if the dev-mode is enabled or not, use
|
||||
The PrimAITE dev-mode can be disabled via the use of
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode show
|
||||
|
||||
dev-mode configuration
|
||||
======================
|
||||
|
||||
The following configures some specific items that the dev-mode overrides, if enabled.
|
||||
|
||||
`--sys-log-level` or `-level`
|
||||
----------------------------
|
||||
|
||||
The level of system logs can be overridden by dev-mode.
|
||||
|
||||
By default, this is set to DEBUG
|
||||
|
||||
The available options are [DEBUG|INFO|WARNING|ERROR|CRITICAL]
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config -level INFO
|
||||
|
||||
or
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config --sys-log-level INFO
|
||||
|
||||
`--output-sys-logs` or `-sys`
|
||||
-----------------------------
|
||||
|
||||
The outputting of system logs can be overridden by dev-mode.
|
||||
|
||||
By default, this is set to False
|
||||
|
||||
Enabling system logs
|
||||
""""""""""""""""""""
|
||||
|
||||
To enable outputting of system logs
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config --output-sys-logs
|
||||
|
||||
or
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config -sys
|
||||
|
||||
Disabling system logs
|
||||
"""""""""""""""""""""
|
||||
|
||||
To disable outputting of system logs
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config --no-sys-logs
|
||||
|
||||
or
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config -nsys
|
||||
|
||||
`--output-pcap-logs` or `-pcap`
|
||||
-------------------------------
|
||||
|
||||
The outputting of packet capture logs can be overridden by dev-mode.
|
||||
|
||||
By default, this is set to False
|
||||
|
||||
Enabling PCAP logs
|
||||
""""""""""""""""""
|
||||
|
||||
To enable outputting of packet capture logs
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config --output-pcap-logs
|
||||
|
||||
or
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config -pcap
|
||||
|
||||
Disabling PCAP logs
|
||||
"""""""""""""""""""
|
||||
|
||||
To disable outputting of packet capture logs
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config --no-pcap-logs
|
||||
|
||||
or
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config -npcap
|
||||
|
||||
`--output-to-terminal` or `-t`
|
||||
------------------------------
|
||||
|
||||
The outputting of system logs to the terminal can be overridden by dev-mode.
|
||||
|
||||
By default, this is set to False
|
||||
|
||||
Enabling system log output to terminal
|
||||
""""""""""""""""""""""""""""""""""""""
|
||||
|
||||
To enable outputting of system logs to terminal
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config --output-to-terminal
|
||||
|
||||
or
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config -t
|
||||
|
||||
Disabling system log output to terminal
|
||||
"""""""""""""""""""""""""""""""""""""""
|
||||
|
||||
To disable outputting of system logs to terminal
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config --no-terminal
|
||||
|
||||
or
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config -nt
|
||||
|
||||
path
|
||||
----
|
||||
|
||||
PrimAITE dev-mode can override where sessions are output.
|
||||
|
||||
By default, PrimAITE will output the sessions in USER_HOME/primaite/sessions
|
||||
|
||||
With dev-mode enabled, by default, this will be changed to PRIMAITE_REPOSITORY_ROOT/sessions
|
||||
|
||||
However, providing a path will let dev-mode output sessions to the given path e.g.
|
||||
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
primaite dev-mode config path ~/output/path
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
primaite dev-mode config path ~\output\path
|
||||
|
||||
default path
|
||||
""""""""""""
|
||||
|
||||
To reset the path to use the PRIMAITE_REPOSITORY_ROOT/sessions, run the command
|
||||
|
||||
.. code-block::
|
||||
|
||||
primaite dev-mode config path --default
|
||||
@@ -161,9 +161,11 @@ To set PrimAITE to run in development mode:
|
||||
.. code-block:: bash
|
||||
:caption: Unix
|
||||
|
||||
primaite mode --dev
|
||||
primaite dev-mode enable
|
||||
|
||||
.. code-block:: powershell
|
||||
:caption: Windows (Powershell)
|
||||
|
||||
primaite mode --dev
|
||||
primaite dev-mode enable
|
||||
|
||||
More information about :ref:`Developer Tools`
|
||||
|
||||
@@ -7,7 +7,7 @@ name = "primaite"
|
||||
description = "PrimAITE (Primary-level AI Training Environment) is a simulation environment for training AI under the ARCD programme."
|
||||
authors = [{name="Defence Science and Technology Laboratory UK", email="oss@dstl.gov.uk"}]
|
||||
license = {file = "LICENSE"}
|
||||
requires-python = ">=3.8, <3.11"
|
||||
requires-python = ">=3.8, <3.12"
|
||||
dynamic = ["version", "readme"]
|
||||
classifiers = [
|
||||
"License :: OSI Approved :: MIT License",
|
||||
@@ -20,6 +20,7 @@ classifiers = [
|
||||
"Programming Language :: Python :: 3.8",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3 :: Only",
|
||||
]
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
3.0.0b9dev
|
||||
3.0.0b9
|
||||
|
||||
@@ -122,35 +122,20 @@ class _PrimaitePaths:
|
||||
PRIMAITE_PATHS: Final[_PrimaitePaths] = _PrimaitePaths()
|
||||
|
||||
|
||||
def _host_primaite_config() -> None:
|
||||
if not PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
shutil.copy2(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
|
||||
|
||||
|
||||
_host_primaite_config()
|
||||
|
||||
|
||||
def _get_primaite_config() -> Dict:
|
||||
config_path = PRIMAITE_PATHS.app_config_file_path
|
||||
if not config_path.exists():
|
||||
# load from package if config does not exist
|
||||
config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
# generate app config
|
||||
shutil.copy2(config_path, PRIMAITE_PATHS.app_config_file_path)
|
||||
with open(config_path, "r") as file:
|
||||
# load from config
|
||||
primaite_config = yaml.safe_load(file)
|
||||
log_level_map = {
|
||||
"NOTSET": logging.NOTSET,
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"WARNING": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
primaite_config["log_level"] = log_level_map[primaite_config["logging"]["log_level"]]
|
||||
return primaite_config
|
||||
return primaite_config
|
||||
|
||||
|
||||
_PRIMAITE_CONFIG = _get_primaite_config()
|
||||
PRIMAITE_CONFIG = _get_primaite_config()
|
||||
|
||||
|
||||
class _LevelFormatter(Formatter):
|
||||
@@ -177,11 +162,11 @@ class _LevelFormatter(Formatter):
|
||||
|
||||
_LEVEL_FORMATTER: Final[_LevelFormatter] = _LevelFormatter(
|
||||
{
|
||||
logging.DEBUG: _PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
|
||||
logging.INFO: _PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
|
||||
logging.WARNING: _PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
|
||||
logging.ERROR: _PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: _PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
|
||||
logging.DEBUG: PRIMAITE_CONFIG["logging"]["logger_format"]["DEBUG"],
|
||||
logging.INFO: PRIMAITE_CONFIG["logging"]["logger_format"]["INFO"],
|
||||
logging.WARNING: PRIMAITE_CONFIG["logging"]["logger_format"]["WARNING"],
|
||||
logging.ERROR: PRIMAITE_CONFIG["logging"]["logger_format"]["ERROR"],
|
||||
logging.CRITICAL: PRIMAITE_CONFIG["logging"]["logger_format"]["CRITICAL"],
|
||||
}
|
||||
)
|
||||
|
||||
@@ -193,10 +178,10 @@ _FILE_HANDLER: Final[RotatingFileHandler] = RotatingFileHandler(
|
||||
backupCount=9, # Max 100MB of logs
|
||||
encoding="utf8",
|
||||
)
|
||||
_STREAM_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
_FILE_HANDLER.setLevel(_PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
_STREAM_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
_FILE_HANDLER.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
|
||||
_LOG_FORMAT_STR: Final[str] = _PRIMAITE_CONFIG["logging"]["logger_format"]
|
||||
_LOG_FORMAT_STR: Final[str] = PRIMAITE_CONFIG["logging"]["logger_format"]
|
||||
_STREAM_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
_FILE_HANDLER.setFormatter(_LEVEL_FORMATTER)
|
||||
|
||||
@@ -215,6 +200,6 @@ def getLogger(name: str) -> Logger: # noqa
|
||||
logging config.
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(_PRIMAITE_CONFIG["log_level"])
|
||||
logger.setLevel(PRIMAITE_CONFIG["logging"]["log_level"])
|
||||
|
||||
return logger
|
||||
|
||||
@@ -2,16 +2,21 @@
|
||||
"""Provides a CLI using Typer as an entry point."""
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pkg_resources
|
||||
import typer
|
||||
import yaml
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.utils.cli import dev_cli
|
||||
|
||||
app = typer.Typer(no_args_is_help=True)
|
||||
app.add_typer(dev_cli.dev, name="dev-mode")
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -89,7 +94,7 @@ def version() -> None:
|
||||
|
||||
|
||||
@app.command()
|
||||
def setup(overwrite_existing: bool = True) -> None:
|
||||
def setup(overwrite_existing: bool = False) -> None:
|
||||
"""
|
||||
Perform the PrimAITE first-time setup.
|
||||
|
||||
@@ -102,11 +107,14 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
_LOGGER.info("Performing the PrimAITE first-time setup...")
|
||||
|
||||
_LOGGER.info("Building primaite_config.yaml...")
|
||||
|
||||
_LOGGER.info("Building the PrimAITE app directories...")
|
||||
PRIMAITE_PATHS.mkdirs()
|
||||
|
||||
_LOGGER.info("Building primaite_config.yaml...")
|
||||
if overwrite_existing:
|
||||
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
shutil.copy(pkg_config_path, PRIMAITE_PATHS.app_config_file_path)
|
||||
|
||||
_LOGGER.info("Rebuilding the demo notebooks...")
|
||||
reset_demo_notebooks.run(overwrite_existing=True)
|
||||
|
||||
@@ -114,47 +122,3 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def mode(
|
||||
dev: Annotated[bool, typer.Option("--dev", help="Activates PrimAITE developer mode")] = None,
|
||||
prod: Annotated[bool, typer.Option("--prod", help="Activates PrimAITE production mode")] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Switch PrimAITE between developer mode and production mode.
|
||||
|
||||
By default, PrimAITE will be in production mode.
|
||||
|
||||
To view the current mode, use: primaite mode
|
||||
|
||||
To set to development mode, use: primaite mode --dev
|
||||
|
||||
To return to production mode, use: primaite mode --prod
|
||||
"""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
if dev and prod:
|
||||
print("Unable to activate developer and production modes concurrently.")
|
||||
return
|
||||
|
||||
if (dev is None) and (prod is None):
|
||||
is_dev_mode = primaite_config["developer_mode"]
|
||||
|
||||
if is_dev_mode:
|
||||
print("PrimAITE is running in developer mode.")
|
||||
else:
|
||||
print("PrimAITE is running in production mode.")
|
||||
if dev:
|
||||
# activate dev mode
|
||||
primaite_config["developer_mode"] = True
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print("PrimAITE is running in developer mode.")
|
||||
if prod:
|
||||
# activate prod mode
|
||||
primaite_config["developer_mode"] = False
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
yaml.dump(primaite_config, file)
|
||||
print("PrimAITE is running in production mode.")
|
||||
|
||||
@@ -45,15 +45,9 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(cfg['agents'][2]['agent_settings'])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for agent in cfg['agents']:\n",
|
||||
" if agent[\"ref\"] == \"defender\":\n",
|
||||
" agent['agent_settings']['flatten_obs'] = True\n",
|
||||
"env_config = cfg\n",
|
||||
"\n",
|
||||
"config = (\n",
|
||||
@@ -80,7 +74,7 @@
|
||||
"tune.Tuner(\n",
|
||||
" \"PPO\",\n",
|
||||
" run_config=air.RunConfig(\n",
|
||||
" stop={\"timesteps_total\": 512}\n",
|
||||
" stop={\"timesteps_total\": 1e3 * 128}\n",
|
||||
" ),\n",
|
||||
" param_space=config\n",
|
||||
").fit()\n"
|
||||
|
||||
148
src/primaite/notebooks/multi-processing.ipynb
Normal file
148
src/primaite/notebooks/multi-processing.ipynb
Normal file
@@ -0,0 +1,148 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Simple multi-processing demo using SubprocVecEnv from SB3\n",
|
||||
"Based on a code example provided by Rachael Proctor."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Import packages and read config file."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import yaml\n",
|
||||
"from stable_baselines3 import PPO\n",
|
||||
"from stable_baselines3.common.utils import set_random_seed\n",
|
||||
"from stable_baselines3.common.vec_env import SubprocVecEnv\n",
|
||||
"\n",
|
||||
"from primaite.session.environment import PrimaiteGymEnv\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.config.load import data_manipulation_config_path"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"with open(data_manipulation_config_path(), 'r') as f:\n",
|
||||
" cfg = yaml.safe_load(f)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Set up training data."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"EPISODE_LEN = 128\n",
|
||||
"NUM_EPISODES = 10\n",
|
||||
"NO_STEPS = EPISODE_LEN * NUM_EPISODES\n",
|
||||
"BATCH_SIZE = 32\n",
|
||||
"LEARNING_RATE = 3e-4\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Define an environment function."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"\n",
|
||||
"def make_env(rank: int, seed: int = 0) -> callable:\n",
|
||||
" \"\"\"Wrapper script for _init function.\"\"\"\n",
|
||||
"\n",
|
||||
" def _init() -> PrimaiteGymEnv:\n",
|
||||
" env = PrimaiteGymEnv(env_config=cfg)\n",
|
||||
" env.reset(seed=seed + rank)\n",
|
||||
" model = PPO(\n",
|
||||
" \"MlpPolicy\",\n",
|
||||
" env,\n",
|
||||
" learning_rate=LEARNING_RATE,\n",
|
||||
" n_steps=NO_STEPS,\n",
|
||||
" batch_size=BATCH_SIZE,\n",
|
||||
" verbose=0,\n",
|
||||
" tensorboard_log=\"./PPO_UC2/\",\n",
|
||||
" )\n",
|
||||
" model.learn(total_timesteps=NO_STEPS)\n",
|
||||
" return env\n",
|
||||
"\n",
|
||||
" set_random_seed(seed)\n",
|
||||
" return _init\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Run experiment."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"n_procs = 2\n",
|
||||
"train_env = SubprocVecEnv([make_env(i + n_procs) for i in range(n_procs)])\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": ".venv",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -5,9 +5,9 @@ from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite import _PRIMAITE_ROOT, getLogger, PRIMAITE_CONFIG, PRIMAITE_PATHS
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
from primaite.utils.primaite_config_utils import is_dev_mode
|
||||
from primaite.utils.cli.primaite_config_utils import is_dev_mode
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
@@ -62,12 +62,15 @@ class PrimaiteIO:
|
||||
date_str = timestamp.strftime("%Y-%m-%d")
|
||||
time_str = timestamp.strftime("%H-%M-%S")
|
||||
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
|
||||
# check if running in dev mode
|
||||
if is_dev_mode():
|
||||
# if dev mode, simulation output will be the current working directory
|
||||
session_path = Path.cwd() / "simulation_output" / date_str / time_str
|
||||
else:
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
session_path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str
|
||||
|
||||
# check if there is an output directory set in config
|
||||
if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
|
||||
session_path = Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"]) / "sessions" / date_str / time_str
|
||||
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
return session_path
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
# The main PrimAITE application config file
|
||||
|
||||
developer_mode: False # false by default
|
||||
developer_mode:
|
||||
enabled: False # not enabled by default
|
||||
sys_log_level: DEBUG # level of output for system logs, DEBUG by default
|
||||
output_sys_logs: False # system logs not output by default
|
||||
output_pcap_logs: False # pcap logs not output by default
|
||||
output_to_terminal: False # do not output to terminal by default
|
||||
output_dir: null # none by default - none will print to repository root
|
||||
|
||||
# Logging
|
||||
logging:
|
||||
|
||||
@@ -3,10 +3,12 @@ from datetime import datetime
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import _PRIMAITE_ROOT
|
||||
from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG, PRIMAITE_PATHS
|
||||
|
||||
__all__ = ["SIM_OUTPUT"]
|
||||
|
||||
from primaite.utils.cli.primaite_config_utils import is_dev_mode
|
||||
|
||||
|
||||
class LogLevel(IntEnum):
|
||||
"""Enum containing all the available log levels for PrimAITE simulation output."""
|
||||
@@ -25,16 +27,34 @@ class LogLevel(IntEnum):
|
||||
|
||||
class _SimOutput:
|
||||
def __init__(self):
|
||||
self._path: Path = (
|
||||
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
)
|
||||
self.save_pcap_logs: bool = False
|
||||
self.save_sys_logs: bool = False
|
||||
self.write_sys_log_to_terminal: bool = False
|
||||
self.sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
time_str = datetime.now().strftime("%H-%M-%S")
|
||||
|
||||
path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
|
||||
self._path = path
|
||||
self._save_pcap_logs: bool = False
|
||||
self._save_sys_logs: bool = False
|
||||
self._write_sys_log_to_terminal: bool = False
|
||||
self._sys_log_level: LogLevel = LogLevel.WARNING # default log level is at WARNING
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
if is_dev_mode():
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
time_str = datetime.now().strftime("%H-%M-%S")
|
||||
# if dev mode is enabled, if output dir is not set, print to primaite repo root
|
||||
path: Path = _PRIMAITE_ROOT.parent.parent / "sessions" / date_str / time_str / "simulation_output"
|
||||
# otherwise print to output dir
|
||||
if PRIMAITE_CONFIG["developer_mode"]["output_dir"]:
|
||||
path: Path = (
|
||||
Path(PRIMAITE_CONFIG["developer_mode"]["output_dir"])
|
||||
/ "sessions"
|
||||
/ date_str
|
||||
/ time_str
|
||||
/ "simulation_output"
|
||||
)
|
||||
self._path = path
|
||||
return self._path
|
||||
|
||||
@path.setter
|
||||
@@ -42,5 +62,45 @@ class _SimOutput:
|
||||
self._path = new_path
|
||||
self._path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@property
|
||||
def save_pcap_logs(self) -> bool:
|
||||
if is_dev_mode():
|
||||
return PRIMAITE_CONFIG.get("developer_mode").get("output_pcap_logs")
|
||||
return self._save_pcap_logs
|
||||
|
||||
@save_pcap_logs.setter
|
||||
def save_pcap_logs(self, save_pcap_logs: bool) -> None:
|
||||
self._save_pcap_logs = save_pcap_logs
|
||||
|
||||
@property
|
||||
def save_sys_logs(self) -> bool:
|
||||
if is_dev_mode():
|
||||
return PRIMAITE_CONFIG.get("developer_mode").get("output_sys_logs")
|
||||
return self._save_sys_logs
|
||||
|
||||
@save_sys_logs.setter
|
||||
def save_sys_logs(self, save_sys_logs: bool) -> None:
|
||||
self._save_sys_logs = save_sys_logs
|
||||
|
||||
@property
|
||||
def write_sys_log_to_terminal(self) -> bool:
|
||||
if is_dev_mode():
|
||||
return PRIMAITE_CONFIG.get("developer_mode").get("output_to_terminal")
|
||||
return self._write_sys_log_to_terminal
|
||||
|
||||
@write_sys_log_to_terminal.setter
|
||||
def write_sys_log_to_terminal(self, write_sys_log_to_terminal: bool) -> None:
|
||||
self._write_sys_log_to_terminal = write_sys_log_to_terminal
|
||||
|
||||
@property
|
||||
def sys_log_level(self) -> LogLevel:
|
||||
if is_dev_mode():
|
||||
return LogLevel[PRIMAITE_CONFIG.get("developer_mode").get("sys_log_level")]
|
||||
return self._sys_log_level
|
||||
|
||||
@sys_log_level.setter
|
||||
def sys_log_level(self, sys_log_level: LogLevel) -> None:
|
||||
self._sys_log_level = sys_log_level
|
||||
|
||||
|
||||
SIM_OUTPUT = _SimOutput()
|
||||
|
||||
@@ -261,7 +261,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.11"
|
||||
"version": "3.10.12"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -256,9 +256,11 @@
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "22",
|
||||
"metadata": {},
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"source": [
|
||||
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
"Calling `switch.arp.show()` displays the Switch ARP Cache."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -270,13 +272,33 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"switch_1\").sys_log.show()"
|
||||
"network.get_node_by_hostname(\"switch_1\").arp.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "24",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `switch.sys_log.show()` displays the Switch system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "25",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"switch_1\").sys_log.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "26",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Computer/Server Nodes\n",
|
||||
"\n",
|
||||
@@ -285,7 +307,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "25",
|
||||
"id": "27",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -293,26 +315,6 @@
|
||||
"Calling `computer.show()` displays the NICs on the Computer/Server."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "26",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"security_suite\").show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "27",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -322,7 +324,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"security_suite\").arp.show()"
|
||||
"network.get_node_by_hostname(\"security_suite\").show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -330,7 +332,7 @@
|
||||
"id": "29",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Calling `computer.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
"Calling `computer.arp.show()` displays the Computer/Server ARP Cache."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -342,7 +344,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"security_suite\").sys_log.show()"
|
||||
"network.get_node_by_hostname(\"security_suite\").arp.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -350,9 +352,7 @@
|
||||
"id": "31",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Basic Network Comms Check\n",
|
||||
"\n",
|
||||
"We can perform a good old ping to check that Nodes are able to communicate with each other."
|
||||
"Calling `computer.sys_log.show()` displays the Computer/Server system log. By default, only the last 10 log entries are displayed, this can be changed by passing `last_n=<number of log entries>`."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -364,7 +364,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.show(nodes=False, links=False)"
|
||||
"network.get_node_by_hostname(\"security_suite\").sys_log.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -372,7 +372,9 @@
|
||||
"id": "33",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We'll first ping client_1's default gateway."
|
||||
"## Basic Network Comms Check\n",
|
||||
"\n",
|
||||
"We can perform a good old ping to check that Nodes are able to communicate with each other."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -384,27 +386,27 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.10.1\")"
|
||||
"network.show(nodes=False, links=False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "35",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We'll first ping client_1's default gateway."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "35",
|
||||
"id": "36",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_1\").sys_log.show(15)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "36",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
|
||||
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.10.1\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -416,7 +418,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.1\")"
|
||||
"network.get_node_by_hostname(\"client_1\").sys_log.show(15)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -424,7 +426,7 @@
|
||||
"id": "38",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"And finally, we'll ping the web server."
|
||||
"Next, we'll ping the interface of the 192.168.1.0/24 Network on the Router (port 1)."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -436,7 +438,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")"
|
||||
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.1\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -444,7 +446,7 @@
|
||||
"id": "40",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To confirm that the ping was received and processed by the web_server, we can view the sys log"
|
||||
"And finally, we'll ping the web server."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -456,45 +458,45 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"web_server\").sys_log.show()"
|
||||
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "42",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To confirm that the ping was received and processed by the web_server, we can view the sys log"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "43",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"web_server\").sys_log.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "44",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced Network Usage\n",
|
||||
"\n",
|
||||
"We can now use the Network to perform some more advanced things."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "43",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "44",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "45",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:"
|
||||
"Let's attempt to prevent client_2 from being able to ping the web server. First, we'll confirm that it can ping the server first..."
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -506,13 +508,33 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_2\").sys_log.show()"
|
||||
"network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "47",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If we look at the client_2 sys log we can see that the four ICMP echo requests were sent and four ICMP each replies were received:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "48",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_2\").sys_log.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "49",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we'll add an ACL to block ICMP from 192.168.10.22"
|
||||
]
|
||||
@@ -520,7 +542,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "48",
|
||||
"id": "50",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -540,7 +562,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "49",
|
||||
"id": "51",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
@@ -549,32 +571,12 @@
|
||||
"network.get_node_by_hostname(\"router_1\").acl.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "50",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now we attempt (and fail) to ping the web server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "51",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "52",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check that the ping was actually sent by client_2 by viewing the sys log"
|
||||
"Now we attempt (and fail) to ping the web server"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -586,7 +588,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_2\").sys_log.show()"
|
||||
"network.get_node_by_hostname(\"client_2\").ping(\"192.168.1.12\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -594,7 +596,7 @@
|
||||
"id": "54",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"We can check the router sys log to see why the traffic was blocked"
|
||||
"We can check that the ping was actually sent by client_2 by viewing the sys log"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -606,7 +608,7 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"router_1\").sys_log.show()"
|
||||
"network.get_node_by_hostname(\"client_2\").sys_log.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -614,7 +616,7 @@
|
||||
"id": "56",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now a final check to ensure that client_1 can still ping the web_server."
|
||||
"We can check the router sys log to see why the traffic was blocked"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -625,6 +627,26 @@
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"router_1\").sys_log.show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "58",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now a final check to ensure that client_1 can still ping the web_server."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "59",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"network.get_node_by_hostname(\"client_1\").ping(\"192.168.1.12\")"
|
||||
]
|
||||
@@ -632,7 +654,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "58",
|
||||
"id": "60",
|
||||
"metadata": {
|
||||
"tags": []
|
||||
},
|
||||
|
||||
@@ -2,9 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import os.path
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
from primaite import getLogger
|
||||
@@ -21,8 +19,6 @@ class File(FileSystemItemABC):
|
||||
:ivar Folder folder: The folder in which the file resides.
|
||||
:ivar FileType file_type: The type of the file.
|
||||
:ivar Optional[int] sim_size: The simulated file size.
|
||||
:ivar bool real: Indicates if the file is actually a real file in the Node sim fs output.
|
||||
:ivar Optional[Path] sim_path: The path if the file is real.
|
||||
"""
|
||||
|
||||
folder_id: str
|
||||
@@ -33,12 +29,6 @@ class File(FileSystemItemABC):
|
||||
"The type of File."
|
||||
sim_size: Optional[int] = None
|
||||
"The simulated file size."
|
||||
real: bool = False
|
||||
"Indicates whether the File is actually a real file in the Node sim fs output."
|
||||
sim_path: Optional[Path] = None
|
||||
"The Path if real is True."
|
||||
sim_root: Optional[Path] = None
|
||||
"Root path of the simulation."
|
||||
num_access: int = 0
|
||||
"Number of times the file was accessed in the current step."
|
||||
|
||||
@@ -67,13 +57,6 @@ class File(FileSystemItemABC):
|
||||
if not kwargs.get("sim_size"):
|
||||
kwargs["sim_size"] = kwargs["file_type"].default_size
|
||||
super().__init__(**kwargs)
|
||||
if self.real:
|
||||
self.sim_path = self.sim_root / self.path
|
||||
if not self.sim_path.exists():
|
||||
self.sim_path.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(self.sim_path, mode="a"):
|
||||
pass
|
||||
|
||||
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
|
||||
|
||||
@property
|
||||
@@ -92,8 +75,6 @@ class File(FileSystemItemABC):
|
||||
|
||||
:return: The size of the file in bytes.
|
||||
"""
|
||||
if self.real:
|
||||
return os.path.getsize(self.sim_path)
|
||||
return self.sim_size
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
@@ -127,7 +108,7 @@ class File(FileSystemItemABC):
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}")
|
||||
self.sys_log.info(f"Scanning file {path}")
|
||||
self.visible_health_status = self.health_status
|
||||
return True
|
||||
|
||||
@@ -155,17 +136,8 @@ class File(FileSystemItemABC):
|
||||
return False
|
||||
current_hash = None
|
||||
|
||||
# if file is real, read the file contents
|
||||
if self.real:
|
||||
with open(self.sim_path, "rb") as f:
|
||||
file_hash = hashlib.blake2b()
|
||||
while chunk := f.read(8192):
|
||||
file_hash.update(chunk)
|
||||
|
||||
current_hash = file_hash.hexdigest()
|
||||
else:
|
||||
# otherwise get describe_state dict and hash that
|
||||
current_hash = hashlib.blake2b(json.dumps(self.describe_state(), sort_keys=True).encode()).hexdigest()
|
||||
# otherwise get describe_state dict and hash that
|
||||
current_hash = hashlib.blake2b(json.dumps(self.describe_state(), sort_keys=True).encode()).hexdigest()
|
||||
|
||||
# if the previous hash is None, set the current hash to previous
|
||||
if self.previous_hash is None:
|
||||
@@ -188,7 +160,7 @@ class File(FileSystemItemABC):
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}")
|
||||
self.sys_log.info(f"Repaired file {path}")
|
||||
return True
|
||||
|
||||
def corrupt(self) -> bool:
|
||||
@@ -203,7 +175,7 @@ class File(FileSystemItemABC):
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}")
|
||||
self.sys_log.info(f"Corrupted file {path}")
|
||||
return True
|
||||
|
||||
def restore(self) -> bool:
|
||||
@@ -217,7 +189,7 @@ class File(FileSystemItemABC):
|
||||
|
||||
self.num_access += 1 # file was accessed
|
||||
path = self.folder.name + "/" + self.name
|
||||
self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}")
|
||||
self.sys_log.info(f"Restored file {path}")
|
||||
return True
|
||||
|
||||
def delete(self) -> bool:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional
|
||||
|
||||
@@ -230,7 +229,6 @@ class FileSystem(SimComponent):
|
||||
size: Optional[int] = None,
|
||||
file_type: Optional[FileType] = None,
|
||||
folder_name: Optional[str] = None,
|
||||
real: bool = False,
|
||||
) -> File:
|
||||
"""
|
||||
Creates a File and adds it to the list of files.
|
||||
@@ -239,7 +237,6 @@ class FileSystem(SimComponent):
|
||||
:param size: The size the file takes on disk in bytes.
|
||||
:param file_type: The type of the file.
|
||||
:param folder_name: The folder to add the file to.
|
||||
:param real: "Indicates whether the File is actually a real file in the Node sim fs output."
|
||||
"""
|
||||
if folder_name:
|
||||
# check if file with name already exists
|
||||
@@ -258,8 +255,6 @@ class FileSystem(SimComponent):
|
||||
file_type=file_type,
|
||||
folder_id=folder.uuid,
|
||||
folder_name=folder.name,
|
||||
real=real,
|
||||
sim_path=self.sim_root if real else None,
|
||||
sim_root=self.sim_root,
|
||||
sys_log=self.sys_log,
|
||||
)
|
||||
@@ -368,11 +363,6 @@ class FileSystem(SimComponent):
|
||||
# add file to dst
|
||||
dst_folder.add_file(file)
|
||||
self.num_file_creations += 1
|
||||
if file.real:
|
||||
old_sim_path = file.sim_path
|
||||
file.sim_path = file.sim_root / file.path
|
||||
file.sim_path.parent.mkdir(exist_ok=True)
|
||||
shutil.move(old_sim_path, file.sim_path)
|
||||
|
||||
def copy_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str):
|
||||
"""
|
||||
@@ -401,9 +391,6 @@ class FileSystem(SimComponent):
|
||||
|
||||
dst_folder.add_file(file_copy, force=True)
|
||||
|
||||
if file.real:
|
||||
file_copy.sim_path.parent.mkdir(exist_ok=True)
|
||||
shutil.copy2(file.sim_path, file_copy.sim_path)
|
||||
else:
|
||||
self.sys_log.error(f"Unable to copy file. {src_file_name} does not exist.")
|
||||
|
||||
|
||||
@@ -192,7 +192,6 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
|
||||
# Cannot send Frame as the network interface is not enabled
|
||||
return False
|
||||
|
||||
@abstractmethod
|
||||
def receive_frame(self, frame: Frame) -> bool:
|
||||
"""
|
||||
Receives a network frame on the network interface.
|
||||
@@ -200,7 +199,13 @@ class WirelessNetworkInterface(NetworkInterface, ABC):
|
||||
:param frame: The network frame being received.
|
||||
:return: A boolean indicating whether the frame was successfully received.
|
||||
"""
|
||||
pass
|
||||
if self.enabled:
|
||||
frame.set_sent_timestamp()
|
||||
self.pcap.capture_inbound(frame)
|
||||
self._connected_node.receive_frame(frame, self)
|
||||
return True
|
||||
# Cannot receive Frame as the network interface is not enabled
|
||||
return False
|
||||
|
||||
|
||||
class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC):
|
||||
|
||||
@@ -1378,7 +1378,7 @@ class Node(SimComponent):
|
||||
application_instance.configure(server_ip_address=IPv4Address(ip_address))
|
||||
else:
|
||||
pass
|
||||
|
||||
application_instance.install()
|
||||
if application_instance.name in self.software_manager.software:
|
||||
return True
|
||||
else:
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import ClassVar, Dict
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.host.host_node import HostNode
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
|
||||
|
||||
class Computer(HostNode):
|
||||
@@ -29,4 +32,6 @@ class Computer(HostNode):
|
||||
* Web Browser
|
||||
"""
|
||||
|
||||
SYSTEM_SOFTWARE: ClassVar[Dict] = {**HostNode.SYSTEM_SOFTWARE, "FTPClient": FTPClient}
|
||||
|
||||
pass
|
||||
|
||||
@@ -10,7 +10,6 @@ from primaite.simulator.network.transmission.data_link_layer import Frame
|
||||
from primaite.simulator.system.applications.web_browser import WebBrowser
|
||||
from primaite.simulator.system.services.arp.arp import ARP, ARPPacket
|
||||
from primaite.simulator.system.services.dns.dns_client import DNSClient
|
||||
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
|
||||
from primaite.simulator.system.services.icmp.icmp import ICMP
|
||||
from primaite.simulator.system.services.ntp.ntp_client import NTPClient
|
||||
from primaite.utils.validators import IPV4Address
|
||||
@@ -301,7 +300,6 @@ class HostNode(Node):
|
||||
"HostARP": HostARP,
|
||||
"ICMP": ICMP,
|
||||
"DNSClient": DNSClient,
|
||||
"FTPClient": FTPClient,
|
||||
"NTPClient": NTPClient,
|
||||
"WebBrowser": WebBrowser,
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Set
|
||||
from typing import Any, Dict, Optional, Set
|
||||
|
||||
from primaite.interface.request import RequestResponse
|
||||
from primaite.simulator.core import RequestManager, RequestType
|
||||
@@ -33,6 +33,10 @@ class Application(IOSoftware):
|
||||
"The number of times the application has been executed. Default is 0."
|
||||
groups: Set[str] = set()
|
||||
"The set of groups to which the application belongs."
|
||||
install_duration: int = 2
|
||||
"How long it takes to install the application."
|
||||
install_countdown: Optional[int] = None
|
||||
"The countdown to the end of the installation process. None if not currently installing"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -76,6 +80,12 @@ class Application(IOSoftware):
|
||||
:param timestep: The current timestep of the simulation.
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
if self.operating_state is ApplicationOperatingState.INSTALLING:
|
||||
self.install_countdown -= 1
|
||||
if self.install_countdown <= 0:
|
||||
self.operating_state = ApplicationOperatingState.RUNNING
|
||||
self.health_state_actual = SoftwareHealthState.GOOD
|
||||
self.install_countdown = None
|
||||
|
||||
def pre_timestep(self, timestep: int) -> None:
|
||||
"""Apply pre-timestep logic."""
|
||||
@@ -129,6 +139,7 @@ class Application(IOSoftware):
|
||||
super().install()
|
||||
if self.operating_state == ApplicationOperatingState.CLOSED:
|
||||
self.operating_state = ApplicationOperatingState.INSTALLING
|
||||
self.install_countdown = self.install_duration
|
||||
|
||||
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
|
||||
"""
|
||||
|
||||
@@ -177,4 +177,5 @@ class DoSBot(DatabaseClient):
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
super().apply_timestep(timestep=timestep)
|
||||
self._application_loop()
|
||||
|
||||
@@ -118,14 +118,6 @@ class RansomwareScript(Application):
|
||||
self.sys_log.info(f"{self.name}: Activated!")
|
||||
self.attack_stage = RansomwareAttackStage.ACTIVATE
|
||||
|
||||
def apply_timestep(self, timestep: int) -> None:
|
||||
"""
|
||||
Apply a timestep to the bot, triggering the application loop.
|
||||
|
||||
:param timestep: The timestep value to update the bot's state.
|
||||
"""
|
||||
pass
|
||||
|
||||
def run(self) -> bool:
|
||||
"""Calls the parent classes execute method before starting the application loop."""
|
||||
super().run()
|
||||
|
||||
@@ -126,7 +126,6 @@ class FTPClient(FTPServiceABC):
|
||||
dest_file_name: str,
|
||||
dest_port: Optional[Port] = Port.FTP,
|
||||
session_id: Optional[str] = None,
|
||||
real_file_path: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Send a file to a target IP address.
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import shutil
|
||||
from abc import ABC
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Dict, Optional
|
||||
@@ -55,19 +54,17 @@ class FTPServiceABC(Service, ABC):
|
||||
file_name = payload.ftp_command_args["dest_file_name"]
|
||||
folder_name = payload.ftp_command_args["dest_folder_name"]
|
||||
file_size = payload.ftp_command_args["file_size"]
|
||||
real_file_path = payload.ftp_command_args.get("real_file_path")
|
||||
health_status = payload.ftp_command_args["health_status"]
|
||||
is_real = real_file_path is not None
|
||||
file = self.file_system.create_file(
|
||||
file_name=file_name, folder_name=folder_name, size=file_size, real=is_real
|
||||
file_name=file_name,
|
||||
folder_name=folder_name,
|
||||
size=file_size,
|
||||
)
|
||||
file.health_status = health_status
|
||||
self.sys_log.info(
|
||||
f"{self.name}: Created item in {self.sys_log.hostname}: {payload.ftp_command_args['dest_folder_name']}/"
|
||||
f"{payload.ftp_command_args['dest_file_name']}"
|
||||
)
|
||||
if is_real:
|
||||
shutil.copy(real_file_path, file.sim_path)
|
||||
# file should exist
|
||||
return self.file_system.get_file(file_name=file_name, folder_name=folder_name) is not None
|
||||
except Exception as e:
|
||||
@@ -115,7 +112,6 @@ class FTPServiceABC(Service, ABC):
|
||||
"dest_folder_name": dest_folder_name,
|
||||
"dest_file_name": dest_file_name,
|
||||
"file_size": file.sim_size,
|
||||
"real_file_path": file.sim_path if file.real else None,
|
||||
"health_status": file.health_status,
|
||||
},
|
||||
packet_payload_size=file.sim_size,
|
||||
|
||||
0
src/primaite/utils/cli/__init__.py
Normal file
0
src/primaite/utils/cli/__init__.py
Normal file
171
src/primaite/utils/cli/dev_cli.py
Normal file
171
src/primaite/utils/cli/dev_cli.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import click
|
||||
import typer
|
||||
from rich import print
|
||||
from rich.table import Table
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from primaite import _PRIMAITE_ROOT, PRIMAITE_CONFIG
|
||||
from primaite.simulator import LogLevel
|
||||
from primaite.utils.cli.primaite_config_utils import is_dev_mode, update_primaite_application_config
|
||||
|
||||
dev = typer.Typer()
|
||||
|
||||
PRODUCTION_MODE_MESSAGE = (
|
||||
"\n[green]:rocket::rocket::rocket: "
|
||||
" PrimAITE is running in Production mode "
|
||||
" :rocket::rocket::rocket: [/green]\n"
|
||||
)
|
||||
|
||||
DEVELOPER_MODE_MESSAGE = (
|
||||
"\n[yellow] :construction::construction::construction: "
|
||||
" PrimAITE is running in Development mode "
|
||||
" :construction::construction::construction: [/yellow]\n"
|
||||
)
|
||||
|
||||
|
||||
def dev_mode():
|
||||
"""
|
||||
CLI commands relevant to the dev-mode for PrimAITE.
|
||||
|
||||
The dev-mode contains tools that help with the ease of developing or debugging PrimAITE.
|
||||
|
||||
By default, PrimAITE will be in production mode.
|
||||
|
||||
To enable development mode, use `primaite dev-mode enable`
|
||||
"""
|
||||
|
||||
|
||||
@dev.command()
|
||||
def show():
|
||||
"""Show if PrimAITE is in development mode or production mode."""
|
||||
# print if dev mode is enabled
|
||||
print(DEVELOPER_MODE_MESSAGE if is_dev_mode() else PRODUCTION_MODE_MESSAGE)
|
||||
|
||||
table = Table(title="Current Dev-Mode Settings")
|
||||
table.add_column("Setting", style="cyan")
|
||||
table.add_column("Value", style="default")
|
||||
for setting, value in PRIMAITE_CONFIG["developer_mode"].items():
|
||||
table.add_row(setting, str(value))
|
||||
|
||||
print(table)
|
||||
print("\nTo see available options, use [cyan]`primaite dev-mode --help`[/cyan]\n")
|
||||
|
||||
|
||||
@dev.command()
|
||||
def enable():
|
||||
"""Enable the development mode for PrimAITE."""
|
||||
# enable dev mode
|
||||
PRIMAITE_CONFIG["developer_mode"]["enabled"] = True
|
||||
update_primaite_application_config()
|
||||
print(DEVELOPER_MODE_MESSAGE)
|
||||
|
||||
|
||||
@dev.command()
|
||||
def disable():
|
||||
"""Disable the development mode for PrimAITE."""
|
||||
# disable dev mode
|
||||
PRIMAITE_CONFIG["developer_mode"]["enabled"] = False
|
||||
update_primaite_application_config()
|
||||
print(PRODUCTION_MODE_MESSAGE)
|
||||
|
||||
|
||||
def config_callback(
|
||||
ctx: typer.Context,
|
||||
sys_log_level: Annotated[
|
||||
LogLevel,
|
||||
typer.Option(
|
||||
"--sys-log-level",
|
||||
"-level",
|
||||
click_type=click.Choice(LogLevel._member_names_, case_sensitive=False),
|
||||
help="The level of system logs to output.",
|
||||
show_default=False,
|
||||
),
|
||||
] = None,
|
||||
output_sys_logs: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--output-sys-logs/--no-sys-logs", "-sys/-nsys", help="Output system logs to file.", show_default=False
|
||||
),
|
||||
] = None,
|
||||
output_pcap_logs: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--output-pcap-logs/--no-pcap-logs",
|
||||
"-pcap/-npcap",
|
||||
help="Output network packet capture logs to file.",
|
||||
show_default=False,
|
||||
),
|
||||
] = None,
|
||||
output_to_terminal: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--output-to-terminal/--no-terminal", "-t/-nt", help="Output system logs to terminal.", show_default=False
|
||||
),
|
||||
] = None,
|
||||
):
|
||||
"""Configure the development tools and environment."""
|
||||
if ctx.params.get("sys_log_level") is not None:
|
||||
PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] = ctx.params.get("sys_log_level")
|
||||
print(f"PrimAITE dev-mode config updated sys_log_level={ctx.params.get('sys_log_level')}")
|
||||
|
||||
if output_sys_logs is not None:
|
||||
PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] = output_sys_logs
|
||||
print(f"PrimAITE dev-mode config updated {output_sys_logs=}")
|
||||
|
||||
if output_pcap_logs is not None:
|
||||
PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] = output_pcap_logs
|
||||
print(f"PrimAITE dev-mode config updated {output_pcap_logs=}")
|
||||
|
||||
if output_to_terminal is not None:
|
||||
PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] = output_to_terminal
|
||||
print(f"PrimAITE dev-mode config updated {output_to_terminal=}")
|
||||
|
||||
# update application config
|
||||
update_primaite_application_config()
|
||||
|
||||
|
||||
config_typer = typer.Typer(
|
||||
callback=config_callback,
|
||||
name="config",
|
||||
no_args_is_help=True,
|
||||
invoke_without_command=True,
|
||||
)
|
||||
dev.add_typer(config_typer)
|
||||
|
||||
|
||||
@config_typer.command()
|
||||
def path(
|
||||
directory: Annotated[
|
||||
str,
|
||||
typer.Argument(
|
||||
help="Directory where the system logs and PCAP logs will be output. By default, this will be where the"
|
||||
"root of the PrimAITE repository is located.",
|
||||
show_default=False,
|
||||
),
|
||||
] = None,
|
||||
default: Annotated[
|
||||
bool,
|
||||
typer.Option(
|
||||
"--default",
|
||||
"-root",
|
||||
help="Set PrimAITE to output system logs and pcap logs to the PrimAITE repository root.",
|
||||
),
|
||||
] = None,
|
||||
):
|
||||
"""Set the output directory for the PrimAITE system and PCAP logs."""
|
||||
if default:
|
||||
PRIMAITE_CONFIG["developer_mode"]["output_dir"] = None
|
||||
# update application config
|
||||
update_primaite_application_config()
|
||||
print(
|
||||
f"PrimAITE dev-mode output_dir [cyan]"
|
||||
f"{str(_PRIMAITE_ROOT.parent.parent / 'simulation_output')}"
|
||||
f"[/cyan]"
|
||||
)
|
||||
return
|
||||
|
||||
if directory:
|
||||
PRIMAITE_CONFIG["developer_mode"]["output_dir"] = directory
|
||||
# update application config
|
||||
update_primaite_application_config()
|
||||
print(f"PrimAITE dev-mode output_dir [cyan]{directory}[/cyan]")
|
||||
22
src/primaite/utils/cli/primaite_config_utils.py
Normal file
22
src/primaite/utils/cli/primaite_config_utils.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import PRIMAITE_CONFIG, PRIMAITE_PATHS
|
||||
|
||||
|
||||
def is_dev_mode() -> bool:
|
||||
"""Returns True if PrimAITE is currently running in developer mode."""
|
||||
return PRIMAITE_CONFIG.get("developer_mode", {}).get("enabled", False)
|
||||
|
||||
|
||||
def update_primaite_application_config(config: Optional[Dict] = None) -> None:
|
||||
"""
|
||||
Update the PrimAITE application config file.
|
||||
|
||||
:params: config: Leave empty so that PRIMAITE_CONFIG is used - otherwise provide the Dict
|
||||
"""
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "w") as file:
|
||||
if not config:
|
||||
config = PRIMAITE_CONFIG
|
||||
yaml.dump(config, file)
|
||||
@@ -1,11 +0,0 @@
|
||||
import yaml
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
|
||||
def is_dev_mode() -> bool:
|
||||
"""Returns True if PrimAITE is currently running in developer mode."""
|
||||
if PRIMAITE_PATHS.app_config_file_path.exists():
|
||||
with open(PRIMAITE_PATHS.app_config_file_path, "r") as file:
|
||||
primaite_config = yaml.safe_load(file)
|
||||
return primaite_config["developer_mode"]
|
||||
@@ -1,11 +1,8 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, Tuple
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
|
||||
from primaite import getLogger, PRIMAITE_PATHS
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
@@ -13,7 +10,6 @@ from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.container import Network
|
||||
from primaite.simulator.network.hardware.nodes.host.computer import Computer
|
||||
@@ -32,7 +28,6 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
|
||||
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
@@ -40,21 +35,6 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def set_syslog_output_to_true():
|
||||
"""Will be run before each test."""
|
||||
monkeypatch = MonkeyPatch()
|
||||
monkeypatch.setattr(
|
||||
SIM_OUTPUT,
|
||||
"path",
|
||||
Path(TEST_ASSETS_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")),
|
||||
)
|
||||
monkeypatch.setattr(SIM_OUTPUT, "save_pcap_logs", False)
|
||||
monkeypatch.setattr(SIM_OUTPUT, "save_sys_logs", False)
|
||||
|
||||
yield
|
||||
|
||||
|
||||
class TestService(Service):
|
||||
"""Test Service class"""
|
||||
|
||||
@@ -86,7 +66,10 @@ class TestApplication(Application):
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def uc2_network() -> Network:
|
||||
return arcd_uc2_network()
|
||||
with open(PRIMAITE_PATHS.user_config_path / "example_config" / "data_manipulation.yaml") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
return game.simulation.network
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
|
||||
@@ -38,3 +38,5 @@ def test_rllib_single_agent_compatibility():
|
||||
save_file = Path(tempfile.gettempdir()) / "ray/"
|
||||
algo.save(save_file)
|
||||
assert save_file.exists()
|
||||
|
||||
save_file.unlink() # clean up
|
||||
|
||||
@@ -25,3 +25,4 @@ def test_sb3_compatibility():
|
||||
model.save(save_path)
|
||||
|
||||
assert (save_path).exists()
|
||||
save_path.unlink() # clean up
|
||||
|
||||
@@ -26,6 +26,9 @@ def test_data_manipulation(uc2_network):
|
||||
# First check that the DB client on the web_server can successfully query the users table on the database
|
||||
assert db_connection.query("SELECT")
|
||||
|
||||
db_manipulation_bot.data_manipulation_p_of_success = 1.0
|
||||
db_manipulation_bot.port_scan_p_of_success = 1.0
|
||||
|
||||
# Now we run the DataManipulationBot
|
||||
db_manipulation_bot.attack()
|
||||
|
||||
@@ -59,6 +62,11 @@ def test_application_install_uninstall_on_uc2():
|
||||
_, _, _, _, info = env.step(78)
|
||||
assert "DoSBot" in domcon.software_manager.software
|
||||
|
||||
# installing takes 3 steps so let's wait for 3 steps
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
env.step(0)
|
||||
|
||||
# Test we can now execute the DoSBot app
|
||||
_, _, _, _, info = env.step(81)
|
||||
assert info["agent_actions"]["defender"].response.status == "success"
|
||||
|
||||
11
tests/integration_tests/cli/__init__.py
Normal file
11
tests/integration_tests/cli/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from typing import List
|
||||
|
||||
from typer.testing import CliRunner, Result
|
||||
|
||||
from primaite.cli import app
|
||||
|
||||
|
||||
def cli(args: List[str]) -> Result:
|
||||
"""Pass in a list of arguments and it will return the result."""
|
||||
runner = CliRunner()
|
||||
return runner.invoke(app, args)
|
||||
171
tests/integration_tests/cli/test_dev_cli.py
Normal file
171
tests/integration_tests/cli/test_dev_cli.py
Normal file
@@ -0,0 +1,171 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pkg_resources
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite import PRIMAITE_CONFIG
|
||||
from primaite.utils.cli.primaite_config_utils import update_primaite_application_config
|
||||
from tests.integration_tests.cli import cli
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def test_setup():
|
||||
"""
|
||||
Setup this test by using the default primaite app config in package
|
||||
"""
|
||||
global PRIMAITE_CONFIG
|
||||
current_config = PRIMAITE_CONFIG.copy() # store the config before test
|
||||
|
||||
pkg_config_path = Path(pkg_resources.resource_filename("primaite", "setup/_package_data/primaite_config.yaml"))
|
||||
|
||||
with open(pkg_config_path, "r") as file:
|
||||
# load from config
|
||||
config_dict = yaml.safe_load(file)
|
||||
|
||||
PRIMAITE_CONFIG["developer_mode"] = config_dict["developer_mode"]
|
||||
|
||||
yield
|
||||
|
||||
PRIMAITE_CONFIG["developer_mode"] = current_config["developer_mode"] # restore config to prevent being yelled at
|
||||
update_primaite_application_config(config=PRIMAITE_CONFIG)
|
||||
|
||||
|
||||
def test_dev_mode_enable_disable():
|
||||
"""Test dev mode enable and disable."""
|
||||
# check defaults
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["enabled"] is False # not enabled by default
|
||||
|
||||
result = cli(["dev-mode", "show"])
|
||||
assert "Production" in result.output # should print that it is in Production mode by default
|
||||
|
||||
result = cli(["dev-mode", "enable"])
|
||||
|
||||
assert "Development" in result.output # should print that it is in Development mode
|
||||
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["enabled"] # config should reflect that dev mode is enabled
|
||||
|
||||
result = cli(["dev-mode", "show"])
|
||||
assert "Development" in result.output # should print that it is in Development mode
|
||||
|
||||
result = cli(["dev-mode", "disable"])
|
||||
|
||||
assert "Production" in result.output # should print that it is in Production mode
|
||||
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["enabled"] is False # config should reflect that dev mode is disabled
|
||||
|
||||
result = cli(["dev-mode", "show"])
|
||||
assert "Production" in result.output # should print that it is in Production mode
|
||||
|
||||
|
||||
def test_dev_mode_config_sys_log_level():
|
||||
"""Check that the system log level can be changed via CLI."""
|
||||
# check defaults
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "DEBUG" # DEBUG by default
|
||||
|
||||
result = cli(["dev-mode", "config", "-level", "WARNING"])
|
||||
|
||||
assert "sys_log_level=WARNING" in result.output # should print correct value
|
||||
|
||||
# config should reflect that log level is WARNING
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "WARNING"
|
||||
|
||||
result = cli(["dev-mode", "config", "--sys-log-level", "INFO"])
|
||||
|
||||
assert "sys_log_level=INFO" in result.output # should print correct value
|
||||
|
||||
# config should reflect that log level is WARNING
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["sys_log_level"] == "INFO"
|
||||
|
||||
|
||||
def test_dev_mode_config_sys_logs_enable_disable():
|
||||
"""Test that the system logs output can be enabled or disabled."""
|
||||
# check defaults
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False # False by default
|
||||
|
||||
result = cli(["dev-mode", "config", "--output-sys-logs"])
|
||||
assert "output_sys_logs=True" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_sys_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"]
|
||||
|
||||
result = cli(["dev-mode", "config", "--no-sys-logs"])
|
||||
assert "output_sys_logs=False" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_sys_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False
|
||||
|
||||
result = cli(["dev-mode", "config", "-sys"])
|
||||
assert "output_sys_logs=True" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_sys_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"]
|
||||
|
||||
result = cli(["dev-mode", "config", "-nsys"])
|
||||
assert "output_sys_logs=False" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_sys_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_sys_logs"] is False
|
||||
|
||||
|
||||
def test_dev_mode_config_pcap_logs_enable_disable():
|
||||
"""Test that the pcap logs output can be enabled or disabled."""
|
||||
# check defaults
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] is False # False by default
|
||||
|
||||
result = cli(["dev-mode", "config", "--output-pcap-logs"])
|
||||
assert "output_pcap_logs=True" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_pcap_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"]
|
||||
|
||||
result = cli(["dev-mode", "config", "--no-pcap-logs"])
|
||||
assert "output_pcap_logs=False" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_pcap_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] is False
|
||||
|
||||
result = cli(["dev-mode", "config", "-pcap"])
|
||||
assert "output_pcap_logs=True" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_pcap_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"]
|
||||
|
||||
result = cli(["dev-mode", "config", "-npcap"])
|
||||
assert "output_pcap_logs=False" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_pcap_logs is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_pcap_logs"] is False
|
||||
|
||||
|
||||
def test_dev_mode_config_output_to_terminal_enable_disable():
|
||||
"""Test that the output to terminal can be enabled or disabled."""
|
||||
# check defaults
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] is False # False by default
|
||||
|
||||
result = cli(["dev-mode", "config", "--output-to-terminal"])
|
||||
assert "output_to_terminal=True" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_to_terminal is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"]
|
||||
|
||||
result = cli(["dev-mode", "config", "--no-terminal"])
|
||||
assert "output_to_terminal=False" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_to_terminal is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] is False
|
||||
|
||||
result = cli(["dev-mode", "config", "-t"])
|
||||
assert "output_to_terminal=True" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_to_terminal is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"]
|
||||
|
||||
result = cli(["dev-mode", "config", "-nt"])
|
||||
assert "output_to_terminal=False" in result.output # should print correct value
|
||||
|
||||
# config should reflect that output_to_terminal is True
|
||||
assert PRIMAITE_CONFIG["developer_mode"]["output_to_terminal"] is False
|
||||
@@ -83,7 +83,7 @@ def test_sometech_webserver_cannot_access_ftp_on_sometech_storage_server():
|
||||
some_tech_storage_srv.file_system.create_file(file_name="test.png")
|
||||
|
||||
web_server: Server = network.get_node_by_hostname("some_tech_web_srv")
|
||||
|
||||
web_server.software_manager.install(FTPClient)
|
||||
web_ftp_client: FTPClient = web_server.software_manager.software["FTPClient"]
|
||||
|
||||
assert not web_ftp_client.request_file(
|
||||
|
||||
@@ -101,7 +101,7 @@ def test_database_client_native_connection_query(uc2_network):
|
||||
"""Tests DB query across the network returns HTTP status 200 and date."""
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
|
||||
|
||||
db_client.connect()
|
||||
assert db_client.query(sql="SELECT")
|
||||
assert db_client.query(sql="INSERT")
|
||||
|
||||
@@ -222,6 +222,7 @@ def test_database_client_cannot_query_offline_database_server(uc2_network):
|
||||
|
||||
web_server: Server = uc2_network.get_node_by_hostname("web_server")
|
||||
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
|
||||
db_client.connect()
|
||||
assert len(db_client.client_connections)
|
||||
|
||||
# Establish a new connection to the DatabaseService
|
||||
|
||||
@@ -57,20 +57,6 @@ def test_simulated_file_check_hash(file_system):
|
||||
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented")
|
||||
def test_real_file_check_hash(file_system):
|
||||
file: File = file_system.create_file(file_name="test_file.txt", real=True)
|
||||
|
||||
file.check_hash()
|
||||
assert file.health_status == FileSystemItemHealthStatus.GOOD
|
||||
# change file content
|
||||
with open(file.sim_path, "a") as f:
|
||||
f.write("get hacked scrub lol xD\n")
|
||||
|
||||
file.check_hash()
|
||||
assert file.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
|
||||
def test_file_corrupt_repair_restore(file_system):
|
||||
"""Test the ability to corrupt and repair files."""
|
||||
file: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder")
|
||||
|
||||
@@ -191,7 +191,7 @@ def test_copy_file(file_system):
|
||||
file_system.create_folder(folder_name="src_folder")
|
||||
file_system.create_folder(folder_name="dst_folder")
|
||||
|
||||
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True)
|
||||
file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder")
|
||||
assert file_system.num_file_creations == 1
|
||||
original_uuid = file.uuid
|
||||
|
||||
|
||||
@@ -132,21 +132,3 @@ def test_simulated_folder_check_hash(file_system):
|
||||
file.sim_size = 0
|
||||
folder.check_hash()
|
||||
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="NODE_FILE_CHECKHASH not implemented")
|
||||
def test_real_folder_check_hash(file_system):
|
||||
folder: Folder = file_system.create_folder(folder_name="test_folder")
|
||||
file_system.create_file(file_name="test_file.txt", folder_name="test_folder", real=True)
|
||||
|
||||
folder.check_hash()
|
||||
assert folder.health_status == FileSystemItemHealthStatus.GOOD
|
||||
# change simulated file size
|
||||
file = folder.get_file(file_name="test_file.txt")
|
||||
|
||||
# change file content
|
||||
with open(file.sim_path, "a") as f:
|
||||
f.write("get hacked scrub lol xD\n")
|
||||
|
||||
folder.check_hash()
|
||||
assert folder.health_status == FileSystemItemHealthStatus.CORRUPT
|
||||
|
||||
@@ -2,10 +2,20 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from primaite import PRIMAITE_CONFIG
|
||||
from primaite.simulator import LogLevel, SIM_OUTPUT
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def override_dev_mode_temporarily():
|
||||
"""Temporarily turn off dev mode for this test."""
|
||||
primaite_dev_mode = PRIMAITE_CONFIG["developer_mode"]["enabled"]
|
||||
PRIMAITE_CONFIG["developer_mode"]["enabled"] = False
|
||||
yield # run tests
|
||||
PRIMAITE_CONFIG["developer_mode"]["enabled"] = primaite_dev_mode
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def syslog() -> SysLog:
|
||||
return SysLog(hostname="test")
|
||||
|
||||
Reference in New Issue
Block a user