Merge branch 'dev' into feature/2041_2042-Add-NTP-Services

This commit is contained in:
Nick Todd
2023-12-04 11:23:35 +00:00
101 changed files with 3803 additions and 1047 deletions

View File

@@ -6,6 +6,9 @@ trigger:
- bugfix/*
- release/*
pr:
autoCancel: true
drafts: false
parameters:
# https://stackoverflow.com/a/70046417
- name: matrix
@@ -85,6 +88,33 @@ stages:
primaite setup
displayName: 'Perform PrimAITE Setup'
- task: UseDotNet@2
displayName: 'Install dotnet dependencies'
inputs:
packageType: 'sdk'
version: '2.1.x'
- script: |
pytest -n auto
displayName: 'Run tests'
coverage run -m --source=primaite pytest -v -o junit_family=xunit2 --junitxml=junit/test-results.xml --cov-fail-under=80
coverage xml -o coverage.xml -i
coverage html -d htmlcov -i
displayName: 'Run tests and code coverage'
- task: PublishTestResults@2
condition: succeededOrFailed()
inputs:
testRunner: JUnit
testResultsFiles: 'junit/**.xml'
testRunTitle: 'Publish test results'
- publish: $(System.DefaultWorkingDirectory)/htmlcov/
# publish the html report - so we can debug the coverage if needed
condition: ${{ item.every_time }} # should only be run once
artifact: coverage_report
- task: PublishCodeCoverageResults@2
# publish the code coverage so it can be viewed in the run coverage page
condition: ${{ item.every_time }} # should only be run once
inputs:
codeCoverageTool: Cobertura
summaryFileLocation: '$(System.DefaultWorkingDirectory)/**/coverage.xml'

2
.gitignore vendored
View File

@@ -37,6 +37,7 @@ pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
junit/
htmlcov/
.tox/
.nox/
@@ -154,3 +155,4 @@ simulation_output/
benchmark/output
# src/primaite/notebooks/scratch.ipynb
src/primaite/notebooks/scratch.py
sandbox.py

View File

@@ -31,15 +31,17 @@ SessionManager.
- `DatabaseClient` and `DatabaseService` created to allow emulation of database actions
- Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup
- Red Agent Services:
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database)
- Data Manipulator Bot - A red agent service which sends a payload to a target machine. (By default this payload is a SQL query that breaks a database). The attack runs in stages with a random, configurable probability of succeeding.
- `DataManipulationAgent` runs the Data Manipulator Bot according to a configured start step, frequency and variance.
- DNS Services: `DNSClient` and `DNSServer`
- FTP Services: `FTPClient` and `FTPServer`
- HTTP Services: `WebBrowser` to simulate a web client and `WebServer`
- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off
- NTP Services: `NTPClient` and `NTPServer`
### Removed
- Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol`
- Removed legacy training modules, they are replaced by the new ARCD GATE dependency
- Removed legacy training modules
- Removed tests for legacy code

View File

@@ -22,8 +22,6 @@ PrimAITE presents the following features:
- Routers with traffic routing and firewall capabilities
- Integration with ARCD GATE for agent training
- Support for multiple agents, each having their own customisable observation space, action space, and reward function definition, and either deterministic or RL-directed behaviour
## Getting Started with PrimAITE

View File

@@ -30,7 +30,7 @@ PrimAITE incorporates the following features:
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life and adversarial behaviour;
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
- Application of traffic to the links of the platform / system laydown adheres to the ACL ruleset;
- Presents both an OpenAI gym and Ray RLLib interface to the environment, allowing integration with any compliant defensive agents;
- Presents both a Gymnasium and Ray RLLib interface to the environment, allowing integration with any compliant defensive agents;
- Allows for the saving and loading of trained defensive agents;
- Stochastic adversarial agent behaviour;
- Full capture of discrete logs relating to agent training or evaluation (system state, agent actions taken, instantaneous and average reward for every step of every episode);
@@ -40,18 +40,18 @@ PrimAITE incorporates the following features:
Architecture
^^^^^^^^^^^^
PrimAITE is a Python application and is therefore Operating System agnostic. The OpenAI gym and Ray RLLib frameworks are employed to provide an interface and source for AI agents. Configuration of PrimAITE is achieved via included YAML files which support full control over the platform / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count. NetworkX based nodes and links host Python classes to present attributes and methods, and hence a more representative platform / system can be modelled within the simulation.
PrimAITE is a Python application and is therefore Operating System agnostic. The Gymnasium and Ray RLLib frameworks are employed to provide an interface and source for AI agents. Configuration of PrimAITE is achieved via included YAML files which support full control over the platform / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count. NetworkX based nodes and links host Python classes to present attributes and methods, and hence a more representative platform / system can be modelled within the simulation.
Training & Evaluation Capability
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its OpenAI Gym and RLLib compliant interface. Scenarios can be constructed to reflect platform / system laydowns consisting of any configuration of nodes (e.g. PCs, servers, switches etc.) and network links between them. All nodes can be configured to model services (and their status) and the traffic loading between them over the network links. Traffic loading is broken down into a per service granularity, relating directly to a protocol (e.g. Service A would be configured as a TCP service, and TCP traffic then flows between instances of Service A under the direction of a tailored IER). Highlights of PrimAITEs training and evaluation capability are:
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its Gymnasium and RLLib compliant interface. Scenarios can be constructed to reflect platform / system laydowns consisting of any configuration of nodes (e.g. PCs, servers, switches etc.) and network links between them. All nodes can be configured to model services (and their status) and the traffic loading between them over the network links. Traffic loading is broken down into a per service granularity, relating directly to a protocol (e.g. Service A would be configured as a TCP service, and TCP traffic then flows between instances of Service A under the direction of a tailored IER). Highlights of PrimAITEs training and evaluation capability are:
- The scenario is not bound to a representation of any platform, system, or technology;
- Fully configurable (network / system laydown, IERs, node pattern-of-life, ACL, number of episodes, steps per episode) and repeatable to suit the requirements of AI agents;
- Can integrate with any OpenAI Gym or RLLib compliant AI agent.
- Can integrate with any Gymnasium or RLLib compliant AI agent.
Use of PrimAITE default scenarios within ARCD is supported by a “Use Case Profile” tailored to the scenario.
@@ -75,7 +75,7 @@ Logs are available in CSV format and provide coverage of the above data for ever
What is PrimAITE built with
---------------------------
* `OpenAI's Gym <https://gym.openai.com/>`_ is used as the basis for AI blue agent interaction with the PrimAITE environment
* `Gymnasium <https://gymnasium.farama.org/>`_ is used as the basis for AI blue agent interaction with the PrimAITE environment
* `Networkx <https://github.com/networkx/networkx>`_ is used as the underlying data structure used for the PrimAITE environment
* `Stable Baselines 3 <https://github.com/DLR-RM/stable-baselines3>`_ is used as a default source of RL algorithms (although PrimAITE is not limited to SB3 agents)
* `Ray RLlib <https://github.com/ray-project/ray>`_ is used as an additional source of RL algorithms
@@ -107,7 +107,6 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/primaite_session
source/simulation
source/game_layer
source/custom_agent
source/config
.. toctree::

View File

@@ -18,7 +18,6 @@ PrimAITE provides the following features:
* Highly configurable network hosts, including definition of software, file system, and network interfaces,
* Realistic network traffic simulation, including address and sending packets via internet protocols like TCP, UDP, ICMP, etc.
* Routers with traffic routing and firewall capabilities
* Interfaces with ARCD GATE to allow training of agents
* Simulation of customisable deterministic agents
* Support for multiple agents, each having their own customisable observation space, action space, and reward function definition.
@@ -148,7 +147,7 @@ The game layer is built on top of the simulator and it consumes the simulation a
Observation Spaces
******************
The observation space provides the blue agent with information about the current status of nodes and links.
PrimAITE builds on top of Gym Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationsHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted.
PrimAITE builds on top of Gymnasium Spaces to create an observation space that is easily configurable for users. It's made up of components which are managed by the :py:class:`primaite.environment.observations.ObservationsHandler`. Each training scenario can define its own observation space, and the user can choose which information to inlude, and how it should be formatted.
NodeLinkTable component
-----------------------
For example, the :py:class:`primaite.environment.observations.NodeLinkTable` component represents the status of nodes and links as a ``gym.spaces.Box`` with an example format shown below:
@@ -279,7 +278,7 @@ The game layer is built on top of the simulator and it consumes the simulation a
3. Any (Agent can take both node-based and ACL-based actions)
The choice of action space used during a training session is determined in the config_[name].yaml file.
**Node-Based**
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is an OpenAI Gym spaces.Discrete type, as follows:
The agent is able to influence the status of nodes by switching them off, resetting, or patching operating systems and services. In this instance, the action space is a Gymnasium spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3,x4] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, num nodes] - Node ID (0 = nothing, node ID)
@@ -287,7 +286,7 @@ The game layer is built on top of the simulator and it consumes the simulation a
* [0, 3] - Action on property (0 = nothing, 1 = on / scan, 2 = off / repair, 3 = reset / patch / restore)
* [0, num services] - Resolves to service ID (0 = nothing, resolves to service)
**Access Control List**
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an OpenAI spaces.Discrete type, as follows:
The blue agent is able to influence the configuration of the Access Control List rule set (which implements a system-wide firewall). In this instance, the action space is an Gymnasium spaces.Discrete type, as follows:
* Dictionary item {... ,1: [x1, x2, x3, x4, x5, x6] ...}
The placeholders inside the list under the key '1' mean the following:
* [0, 2] - Action (0 = do nothing, 1 = create rule, 2 = delete rule)

View File

@@ -1,13 +1,80 @@
Primaite v3 config
******************
PrimAITE uses a single configuration file to define a cybersecurity scenario. This includes the computer network and multiple agents. There are three main sections: training_config, game, and simulation.
PrimAITE uses a single configuration file to define everything needed to train and evaluate an RL policy in a custom cybersecurity scenario. This includes the configuration of the network, the scripted or trained agents that interact with the network, as well as settings that define how to perform training in Stable Baselines 3 or Ray RLLib.
The entire config is used by the ``PrimaiteSession`` object for users who wish to let PrimAITE handle the agent definition and training. If you wish to define custom agents and control the training loop yourself, you can use the config with the ``PrimaiteGame``, and ``PrimaiteGymEnv`` objects instead. That way, only the network configuration and agent setup parts of the config are used, and the training section is ignored.
The simulation section describes the simulated network environment with which the agetns interact.
Configurable items
==================
The game section describes the agents and their capabilities. Each agent has a unique type and is associated with a team (GREEN, RED, or BLUE). Each agent has a configurable observation space, action space, and reward function.
``training_config``
-------------------
This section allows selecting which training framework and algorithm to use, and set some training hyperparameters.
The training_config section describes the training parameters for the learning agents. This includes the number of episodes, the number of steps per episode, and the number of steps before the agents start learning. The training_config section also describes the learning algorithm used by the agents. The learning algorithm is specified by the name of the algorithm and the hyperparameters for the algorithm. The hyperparameters are specific to each algorithm and are described in the documentation for each algorithm.
``io_settings``
---------------
This section configures how the ``PrimaiteSession`` saves data.
.. only:: comment
This needs a bit of refactoring so I haven't written extensive documentation about the config yet.
``game``
--------
This section defines high-level settings that apply across the game, currently it's used to help shape the action and observation spaces by restricting which ports and internet protocols should be considered. Here, users can also set the maximum number of steps in an episode.
``agents``
----------
Agents can be scripted (deterministic and stochastic), or controlled by a reinforcement learning algorithm. Not to be confused with an RL agent, the term agent here is used to refer to an entity that sends requests to the simulated network. In this part of the config, each agent's action space, observation space, and reward function can be defined. All three are defined in a modular way.
**type**: Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour.
**team**: Specifies if the agent is malicious (RED), benign (GREEN), or defensive (BLUE). Currently this value is not used for anything.
**observation space:**
* ``type``: selects which python class from the ``primaite.game.agent.observation`` module is used for the overall observation structure.
* ``options``: allows configuring the chosen observation type. The ``UC2BlueObservation`` should be used for RL Agents.
* ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node`` all define the shape of the observation space. The size and shape of the obs space must remain constant, but the number of files, folders, ACL rules, and other components can change within an episode. Therefore padding is performed and these options set the size of the obs space.
* ``nodes``: list of nodes that will be present in this agent's observation space. The ``node_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. Each node can also be configured with services, and files that should be monitored.
* ``links``: list of links that will be present in this agent's observation space. The ``link_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config.
* ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_address_order`` sets the encoding of ip addresses as integers within the observation space.
**action space:**
The action space is configured to be made up of individual action types. Once configured, the agent can select an action type and some optional action parameters at every step. For example: The ``NODE_SERVICE_SCAN`` action takes the parameters ``node_id`` and ``service_id``.
Description of configurable items:
* ``action_list``: a list of action modules. The options are listed in the ``primaite.game.agent.actions`` module.
* ``action_map``: (optional). Restricts the possible combinations of action type / action parameter values to reduce the overall size of the action space. By default, every possible combination of actions and parameters will be assigned an integer for the agent's ``MultiDiscrete`` action space. Instead, the ``action_map`` allows you to list the actions corresponding to each integer in the ``MultiDiscrete`` space.
* ``options``: Options that apply too all action components.
* ``nodes``: list the nodes that the agent can act on, the order of this list defines the mapping between nodes and ``node_id`` integers.
* ``max_folders_per_node``, ``max_files_per_folder``, ``max_services_per_node``, ``max_nics_per_node``, ``max_acl_rules`` all are used to define the size of the action space.
**reward function:**
Similar to action space, this is defined as a list of components.
Description of configurable items:
* ``reward_components`` a list of reward components from the ``primaite.game.agent.reward`` module.
* ``weight``: relative importance of this reward component. The total reward for a step is a weighted sum of all reward components.
* ``options``: list of options passed to the reward component during initialisation, the exact options required depend on the reward component.
**agent_settings**:
Settings passed to the agent during initialisation. These depend on the agent class.
``simulation``
--------------
In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents.
At the top level of the network are ``nodes`` and ``links``.
**nodes:**
* ``type``: one of ``router``, ``switch``, ``computer``, or ``server``, this affects what other sub-options should be defined.
* ``hostname`` - a non-unique name used for logging and outputs.
* ``num_ports`` (optional, routers and switches only): number of network interfaces present on the device.
* ``ports`` (optional, routers and switches only): configuration for each network interface, including IP address and subnet mask.
* ``acl`` (Router only): Define the ACL rules at each index of the ACL on the router. the possible options are: ``action`` (PERMIT or DENY), ``src_port``, ``dst_port``, ``protocol``, ``src_ip``, ``dst_ip``. Any options left blank default to none which usually means that it will apply across all options. For example leaving ``src_ip`` blank will apply the rule to all IP addresses.
* ``services`` (computers and servers only): a list of services to install on the node. They must define a ``ref``, ``type``, and ``options`` that depend on which ``type`` was selected.
* ``applications`` (computer and servers only): Similar to services. A list of application to install on the node.
* ``nics`` (computers and servers only): If the node has multiple networking devices, the second, third, fourth, etc... must be defined here with an ``ip_address`` and ``subnet_mask``.
**links:**
* ``ref``: unique identifier for this link
* ``endpoint_a_ref``: Reference to the node at the first end of the link
* ``endpoint_a_port``: The ethernet port or switch port index of the second node
* ``endpoint_b_ref``: Reference to the node at the second end of the link
* ``endpoint_b_port``: The ethernet port or switch port index on the second node

View File

@@ -1,14 +0,0 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Custom Agents
=============
Integrating a user defined blue agent
*************************************
.. note::
PrimAITE uses ARCD GATE for agent integration. In order to use a custom agent with PrimAITE, you must integrate it with ARCD GATE. Please look at the ARCD GATE documentation for more information.

View File

@@ -4,9 +4,9 @@ PrimAITE Game layer
The Primaite codebase consists of two main modules:
* ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes.
* ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules, including ARCD GATE.
* ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules.
These two components have been decoupled to allow the agent training code in ARCD GATE to be reused with other simulators. The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API. The game layer communicates with ARCD gate using the `Farama Gymnasium Spaces API <https://gymnasium.farama.org/api/spaces/>`_.
The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API.
..
TODO: write up these APIs and link them here.
@@ -20,13 +20,14 @@ The game layer is responsible for managing agents and getting them to interface
PrimAITE Session
^^^^^^^^^^^^^^^
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. It also sends messages to ARCD GATE to perform reinforcement learning. ``PrimaiteSession`` keeps track of multiple agents of different types.
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
Agents
^^^^^^
All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types:
* RL agents action during each step is decided by an RL algorithm which lives inside of ARCD GATE. The agent within PrimAITE just acts to format and forward actions decided by an RL policy.
* RL agents action during each step is decided by an appropriate RL algorithm. The agent within PrimAITE just acts to format and forward actions decided by an RL policy.
* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable.
..

View File

@@ -87,22 +87,7 @@ Install PrimAITE
pip install path\to\your\primaite.whl
5. Install ARCD GATE from wheel file
.. code-block:: bash
:caption: Unix
pip install path/to/your/arcd_gate-0.1.0-py3-none-any.whl
.. code-block:: powershell
:caption: Windows (Powershell)
pip install path\to\your\arcd_gate-0.1.0-py3-none-any.whl
6. Perform the PrimAITE setup
5. Perform the PrimAITE setup
.. code-block:: bash
:caption: Unix
@@ -153,17 +138,4 @@ of your choice:
pip install -e .[dev]
4. Install ARCD GATE from wheel file
.. code-block:: bash
:caption: Unix
pip install GATE/arcd_gate-0.1.0-py3-none-any.whl
.. code-block:: powershell
:caption: Windows (Powershell)
pip install GATE\arcd_gate-0.1.0-py3-none-any.whl
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).

View File

@@ -74,8 +74,8 @@ Glossary
Laydown
The laydown is a file which defines the training scenario. It contains the network topology, firewall rules, services, protocols, and details about green and red agent behaviours.
Gym
PrimAITE uses the Gym reinforcement learning framework API to create a training environment and interface with RL agents. Gym defines a common way of creating observations, actions, and rewards.
Gymnasium
PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards.
User app home
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>` on linux/darwin and `C:\Users\<username>\primaite\<version>` on Windows.
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>` on linux/darwin and `C:\\Users\\<username>\\primaite\\<version>` on Windows.

View File

@@ -5,12 +5,12 @@
Request System
==============
``SimComponent`` in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``RequestType``.
``SimComponent`` objects in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``RequestType``.
Just like other aspects of SimComponent, the request typess are not managed centrally for the whole simulation, but instead they are dynamically created and updated based on the nodes, links, and other components that currently exist. This was achieved in the following way:
Just like other aspects of SimComponent, the request types are not managed centrally for the whole simulation, but instead they are dynamically created and updated based on the nodes, links, and other components that currently exist. This was achieved in the following way:
- API
An ``RequestType`` contains two elements:
A ``RequestType`` contains two elements:
1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as `['network', 'node', '<node-uuid>', 'service', '<service-uuid>', 'restart']`.
2. ``context`` - optional extra information that can be used to decide how to process the request. This is formatted as a dictionary. For example, if the request requires authentication, the context can include information about the user that initiated the request to decide if their permissions are sufficient.

View File

@@ -23,8 +23,3 @@ Contents
simulation_components/network/network
simulation_components/system/internal_frame_processing
simulation_components/system/software
simulation_components/system/data_manipulation_bot
simulation_components/system/database_client_server
simulation_components/system/dns_client_server
simulation_components/system/ftp_client_server
simulation_components/system/web_browser_and_web_server_service

View File

@@ -109,6 +109,67 @@ e.g.
instant_start_node = Node(hostname="client", start_up_duration=0, shut_down_duration=0)
instant_start_node.power_on() # node will still need to be powered on
.. _Node Start up and Shut down:
---------------------------
Node Start up and Shut down
---------------------------
Nodes are powered on and off over multiple timesteps. By default, the node ``start_up_duration`` and ``shut_down_duration`` is 3 timesteps.
Example code where a node is turned on:
.. code-block:: python
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
node = Node(hostname="pc_a")
assert node.operating_state is NodeOperatingState.OFF # By default, node is instantiated in an OFF state
node.power_on() # power on the node
assert node.operating_state is NodeOperatingState.BOOTING # node is booting up
for i in range(node.start_up_duration + 1):
# apply timestep until the node start up duration
node.apply_timestep(timestep=i)
assert node.operating_state is NodeOperatingState.ON # node is in ON state
If the node needs to be instantiated in an on state:
.. code-block:: python
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
node = Node(hostname="pc_a", operating_state=NodeOperatingState.ON)
assert node.operating_state is NodeOperatingState.ON # node is in ON state
Setting ``start_up_duration`` and/or ``shut_down_duration`` to ``0`` will allow for the ``power_on`` and ``power_off`` methods to be completed instantly without applying timesteps:
.. code-block:: python
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
node = Node(hostname="pc_a", start_up_duration=0, shut_down_duration=0)
assert node.operating_state is NodeOperatingState.OFF # node is in OFF state
node.power_on()
assert node.operating_state is NodeOperatingState.ON # node is in ON state
node.power_off()
assert node.operating_state is NodeOperatingState.OFF # node is in OFF state
------------------
Network Interfaces
------------------
@@ -357,7 +418,6 @@ Creating the four nodes results in:
2023-08-08 15:50:08,359 INFO: Connected NIC 84:20:7c:ec:a5:c6/192.168.0.13
---------------
Create Switches
---------------

View File

@@ -16,43 +16,125 @@ The bot is intended to simulate a malicious actor carrying out attacks like:
- Dropping tables
- Deleting records
- Modifying data
On a database server by abusing an application's trusted database connectivity.
on a database server by abusing an application's trusted database connectivity.
The bot performs attacks in the following stages to simulate the real pattern of an attack:
- Logon - *The bot gains credentials and accesses the node.*
- Port Scan - *The bot finds accessible database servers on the network.*
- Attacking - *The bot delivers the payload to the discovered database servers.*
Each of these stages has a random, configurable probability of succeeding (by default 10%). The bot can also be configured to repeat the attack once complete.
Usage
-----
- Create an instance and call ``configure`` to set:
- Target database server IP
- Database password (if needed)
- SQL statement payload
- Target database server IP
- Database password (if needed)
- SQL statement payload
- Probabilities for succeeding each of the above attack stages
- Call ``run`` to connect and execute the statement.
The bot handles connecting, executing the statement, and disconnecting.
In a simulation, the bot can be controlled by using ``DataManipulationAgent`` which calls ``run`` on the bot at configured timesteps.
Example
-------
.. code-block:: python
client_1 = Computer(
hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", default_gateway="192.168.10.1"
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1"
operating_state=NodeOperatingState.ON # initialise the computer in an ON state
)
client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
data_manipulation_bot.run()
This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table.
Example with ``DataManipulationAgent``
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
If not using the data manipulation bot manually, it needs to be used with a data manipulation agent. Below is an example section of configuration file for setting up a simulation with data manipulation bot and agent.
.. code-block:: yaml
game_config:
# ...
agents:
- ref: data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
applications:
- application_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 25
frequency: 20
variance: 5
# ...
simulation:
network:
nodes:
- ref: client_1
type: computer
# ... additional configuration here
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
Implementation
--------------
The bot extends ``DatabaseClient`` and leverages its connectivity.
- Uses the Application base class for lifecycle management.
- Credentials and target IP set via ``configure``.
- Credentials, target IP and other options set via ``configure``.
- ``run`` handles connecting, executing statement, and disconnecting.
- SQL payload executed via ``query`` method.
- Results in malicious SQL being executed on remote database server.

View File

@@ -45,17 +45,14 @@ Key features
^^^^^^^^^^^^
- Connects to the ``DatabaseService`` via the ``SoftwareManager``.
- Handles connecting and disconnecting.
- Executes SQL queries and retrieves result sets.
- Handles connecting, querying, and disconnecting.
- Provides a simple ``query`` method for running SQL.
Usage
^^^^^
- Initialise with server IP address and optional password.
- Connect to the ``DatabaseService`` with ``connect``.
- Execute SQL queries via ``query``.
- Retrieve results in a dictionary.
- Disconnect when finished.
@@ -70,6 +67,5 @@ Implementation
- Leverages ``SoftwareManager`` for sending payloads over the network.
- Connect and disconnect methods manage sessions.
- Provides easy interface for applications to query database.
- Payloads serialised as dictionaries for transmission.
- Extends base Application class.

View File

@@ -77,6 +77,7 @@ Dependencies
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
Example peer to peer network
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -85,10 +86,18 @@ Example peer to peer network
net = Network()
pc1 = Computer(hostname="pc1", ip_address="120.10.10.10", subnet_mask="255.255.255.0")
srv = Server(hostname="srv", ip_address="120.10.10.20", subnet_mask="255.255.255.0")
pc1.power_on()
srv.power_on()
pc1 = Computer(
hostname="pc1",
ip_address="120.10.10.10",
subnet_mask="255.255.255.0",
operating_state=NodeOperatingState.ON # initialise the computer in an ON state
)
srv = Server(
hostname="srv",
ip_address="120.10.10.20",
subnet_mask="255.255.255.0",
operating_state=NodeOperatingState.ON # initialise the server in an ON state
)
net.connect(pc1.ethernet_port[1], srv.ethernet_port[1])
Install the FTP Server

View File

@@ -6,14 +6,45 @@
Software
========
-------------
Base Software
-------------
All software which inherits ``IOSoftware`` installed on a node will not work unless the node has been turned on.
See :ref:`Node Start up and Shut down`
.. code-block:: python
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.services.web_server.web_server import WebServer
node = Node(hostname="pc_a", start_up_duration=0, shut_down_duration=0)
node.power_on()
assert node.operating_state is NodeOperatingState.ON
node.software_manager.install(WebServer)
web_server: WebServer = node.software_manager.software.get("WebServer")
assert web_server.operating_state is ServiceOperatingState.RUNNING # service is immediately ran after install
node.power_off()
assert node.operating_state is NodeOperatingState.OFF
assert web_server.operating_state is ServiceOperatingState.STOPPED # service stops when node is powered off
node.power_on()
assert node.operating_state is NodeOperatingState.ON
assert web_server.operating_state is ServiceOperatingState.RUNNING # service turned back on when node is powered on
Contents
########
Services, Processes and Applications:
#####################################
.. toctree::
:maxdepth: 8
:maxdepth: 2
database_client_server
data_manipulation_bot

View File

@@ -5,9 +5,9 @@
Simulation State
==============
``SimComponent`` in the simulation have a method called ``describe_state`` which returns a dictionary of the state of the component. This is used to report pertinent data that could impact agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` reports not only it's own attributes in the state but also that of its child components. I.e. a computer node will report the state of its ``FileSystem``, and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling childrens' own ``describe_state`` methods.
``SimComponent`` objects in the simulation have a method called ``describe_state`` which return a dictionary of the state of the component. This is used to report pertinent data that could impact an agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` object reports not only its own attributes in the state but also those of its child components. I.e. a computer node will report the state of its ``FileSystem`` and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling the childrens' own ``describe_state`` methods.
The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then pass the state to the agents once per simulation step. For this reason, all ``SimComponent`` must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``.
The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then passes the state to the agents once per simulation step. For this reason, all ``SimComponent`` objetcs must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``.
This code snippet demonstrates how the state information is defined within the ``SimComponent`` class:

View File

@@ -1 +1 @@
3.0.0a1
3.0.0b2dev

View File

@@ -1,5 +1,5 @@
training_config:
rl_framework: RLLIB_single_agent
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
@@ -36,31 +36,26 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_2
applications:
- application_ref: client_2_web_browser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -69,38 +64,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DELETE"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -110,9 +87,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -562,17 +540,25 @@ simulation:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
0:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
@@ -609,7 +595,7 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
@@ -630,6 +616,10 @@ simulation:
services:
- ref: database_service
type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- ref: database_ftp_client
type: FTPClient
- ref: backup_server
type: server
@@ -640,7 +630,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server
@@ -661,9 +651,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.8
data_manipulation_p_of_success: 0.8
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
@@ -677,10 +673,14 @@ simulation:
applications:
- ref: client_2_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1

View File

@@ -1,14 +1,10 @@
training_config:
rl_framework: RLLIB_single_agent
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 256
deterministic_eval: false
n_agents: 1
rl_framework: RLLIB_multi_agent
# rl_framework: SB3
n_agents: 2
agent_references:
- defender
- defender_1
- defender_2
io_settings:
save_checkpoints: true
@@ -36,31 +32,26 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_ref: client_2
applications:
- application_ref: client_2_web_browser
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
max_applications_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
start_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -69,38 +60,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DELETE"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -110,11 +83,12 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender1
- ref: defender_1
team: BLUE
type: ProxyAgent
@@ -544,7 +518,9 @@ agents:
agent_settings:
# ...
- ref: defender2
- ref: defender_2
team: BLUE
type: ProxyAgent
@@ -992,17 +968,25 @@ simulation:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
0:
18:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
19:
action: PERMIT
src_port: DNS
dst_port: DNS
20:
action: PERMIT
src_port: FTP
dst_port: FTP
21:
action: PERMIT
src_port: HTTP
dst_port: HTTP
22:
action: PERMIT
src_port: ARP
@@ -1039,7 +1023,7 @@ simulation:
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
@@ -1060,6 +1044,10 @@ simulation:
services:
- ref: database_service
type: DatabaseService
options:
backup_server_ip: 192.168.1.16
- ref: database_ftp_client
type: FTPClient
- ref: backup_server
type: server
@@ -1070,7 +1058,7 @@ simulation:
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
type: FTPServer
- ref: security_suite
type: server
@@ -1091,9 +1079,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient
@@ -1107,10 +1101,14 @@ simulation:
applications:
- ref: client_2_web_browser
type: WebBrowser
options:
target_url: http://arcd.com/users/
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
class PrimaiteError(Exception):
"""The root PrimAITe Error."""
"""The root PrimAITE Error."""
pass

View File

@@ -15,7 +15,6 @@ from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from primaite import getLogger
from primaite.simulator.sim_container import Simulation
_LOGGER = getLogger(__name__)
@@ -82,7 +81,7 @@ class NodeServiceAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "service_id": num_services}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, service_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -98,7 +97,7 @@ class NodeServiceScanAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "scan"
self.verb: str = "scan"
class NodeServiceStopAction(NodeServiceAbstractAction):
@@ -106,7 +105,7 @@ class NodeServiceStopAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "stop"
self.verb: str = "stop"
class NodeServiceStartAction(NodeServiceAbstractAction):
@@ -114,7 +113,7 @@ class NodeServiceStartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "start"
self.verb: str = "start"
class NodeServicePauseAction(NodeServiceAbstractAction):
@@ -122,7 +121,7 @@ class NodeServicePauseAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "pause"
self.verb: str = "pause"
class NodeServiceResumeAction(NodeServiceAbstractAction):
@@ -130,7 +129,7 @@ class NodeServiceResumeAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "resume"
self.verb: str = "resume"
class NodeServiceRestartAction(NodeServiceAbstractAction):
@@ -138,7 +137,7 @@ class NodeServiceRestartAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "restart"
self.verb: str = "restart"
class NodeServiceDisableAction(NodeServiceAbstractAction):
@@ -146,7 +145,7 @@ class NodeServiceDisableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "disable"
self.verb: str = "disable"
class NodeServiceEnableAction(NodeServiceAbstractAction):
@@ -154,7 +153,38 @@ class NodeServiceEnableAction(NodeServiceAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_services: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_services=num_services)
self.verb = "enable"
self.verb: str = "enable"
class NodeApplicationAbstractAction(AbstractAction):
"""
Base class for application actions.
Any action which applies to an application and uses node_id and application_id as its only two parameters can
inherit from this base class.
"""
@abstractmethod
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "application_id": num_applications}
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, application_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
node_uuid = self.manager.get_node_uuid_by_idx(node_id)
application_uuid = self.manager.get_application_uuid_by_idx(node_id, application_id)
if node_uuid is None or application_uuid is None:
return ["do_nothing"]
return ["network", "node", node_uuid, "application", application_uuid, self.verb]
class NodeApplicationExecuteAction(NodeApplicationAbstractAction):
"""Action which executes an application."""
def __init__(self, manager: "ActionManager", num_nodes: int, num_applications: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, num_applications=num_applications)
self.verb: str = "execute"
class NodeFolderAbstractAction(AbstractAction):
@@ -169,7 +199,7 @@ class NodeFolderAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -223,7 +253,7 @@ class NodeFileAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "folder_id": num_folders, "file_id": num_files}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, folder_id: int, file_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -240,7 +270,7 @@ class NodeFileScanAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "scan"
self.verb: str = "scan"
class NodeFileCheckhashAction(NodeFileAbstractAction):
@@ -248,7 +278,7 @@ class NodeFileCheckhashAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "checkhash"
self.verb: str = "checkhash"
class NodeFileDeleteAction(NodeFileAbstractAction):
@@ -256,7 +286,7 @@ class NodeFileDeleteAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "delete"
self.verb: str = "delete"
class NodeFileRepairAction(NodeFileAbstractAction):
@@ -264,7 +294,7 @@ class NodeFileRepairAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "repair"
self.verb: str = "repair"
class NodeFileRestoreAction(NodeFileAbstractAction):
@@ -272,7 +302,7 @@ class NodeFileRestoreAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "restore"
self.verb: str = "restore"
class NodeFileCorruptAction(NodeFileAbstractAction):
@@ -280,7 +310,7 @@ class NodeFileCorruptAction(NodeFileAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, num_folders: int, num_files: int, **kwargs) -> None:
super().__init__(manager, num_nodes=num_nodes, num_folders=num_folders, num_files=num_files, **kwargs)
self.verb = "corrupt"
self.verb: str = "corrupt"
class NodeAbstractAction(AbstractAction):
@@ -294,7 +324,7 @@ class NodeAbstractAction(AbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -307,7 +337,7 @@ class NodeOSScanAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "scan"
self.verb: str = "scan"
class NodeShutdownAction(NodeAbstractAction):
@@ -315,7 +345,7 @@ class NodeShutdownAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "shutdown"
self.verb: str = "shutdown"
class NodeStartupAction(NodeAbstractAction):
@@ -323,7 +353,7 @@ class NodeStartupAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "startup"
self.verb: str = "startup"
class NodeResetAction(NodeAbstractAction):
@@ -331,7 +361,7 @@ class NodeResetAction(NodeAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes)
self.verb = "reset"
self.verb: str = "reset"
class NetworkACLAddRuleAction(AbstractAction):
@@ -394,7 +424,7 @@ class NetworkACLAddRuleAction(AbstractAction):
elif permission == 2:
permission_str = "DENY"
else:
_LOGGER.warn(f"{self.__class__} received permission {permission}, expected 0 or 1.")
_LOGGER.warning(f"{self.__class__} received permission {permission}, expected 0 or 1.")
if protocol_id == 0:
return ["do_nothing"] # NOT SUPPORTED, JUST DO NOTHING IF WE COME ACROSS THIS
@@ -489,7 +519,7 @@ class NetworkNICAbstractAction(AbstractAction):
"""
super().__init__(manager=manager)
self.shape: Dict[str, int] = {"node_id": num_nodes, "nic_id": max_nics_per_node}
self.verb: str
self.verb: str # define but don't initialise: defends against children classes not defining this
def form_request(self, node_id: int, nic_id: int) -> List[str]:
"""Return the action formatted as a request which can be ingested by the PrimAITE simulation."""
@@ -512,7 +542,7 @@ class NetworkNICEnableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "enable"
self.verb: str = "enable"
class NetworkNICDisableAction(NetworkNICAbstractAction):
@@ -520,7 +550,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction):
def __init__(self, manager: "ActionManager", num_nodes: int, max_nics_per_node: int, **kwargs) -> None:
super().__init__(manager=manager, num_nodes=num_nodes, max_nics_per_node=max_nics_per_node, **kwargs)
self.verb = "disable"
self.verb: str = "disable"
class ActionManager:
@@ -536,6 +566,7 @@ class ActionManager:
"NODE_SERVICE_RESTART": NodeServiceRestartAction,
"NODE_SERVICE_DISABLE": NodeServiceDisableAction,
"NODE_SERVICE_ENABLE": NodeServiceEnableAction,
"NODE_APPLICATION_EXECUTE": NodeApplicationExecuteAction,
"NODE_FILE_SCAN": NodeFileScanAction,
"NODE_FILE_CHECKHASH": NodeFileCheckhashAction,
"NODE_FILE_DELETE": NodeFileDeleteAction,
@@ -562,9 +593,11 @@ class ActionManager:
game: "PrimaiteGame", # reference to game for information lookup
actions: List[str], # stores list of actions available to agent
node_uuids: List[str], # allows mapping index to node
application_uuids: List[List[str]], # allows mapping index to application
max_folders_per_node: int = 2, # allows calculating shape
max_files_per_folder: int = 2, # allows calculating shape
max_services_per_node: int = 2, # allows calculating shape
max_applications_per_node: int = 10, # allows calculating shape
max_nics_per_node: int = 8, # allows calculating shape
max_acl_rules: int = 10, # allows calculating shape
protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol
@@ -600,8 +633,8 @@ class ActionManager:
:type act_map: Optional[Dict[int, Dict]]
"""
self.game: "PrimaiteGame" = game
self.sim: Simulation = self.game.simulation
self.node_uuids: List[str] = node_uuids
self.application_uuids: List[List[str]] = application_uuids
self.protocols: List[str] = protocols
self.ports: List[str] = ports
@@ -611,7 +644,7 @@ class ActionManager:
else:
self.ip_address_list = []
for node_uuid in self.node_uuids:
node_obj = self.sim.network.nodes[node_uuid]
node_obj = self.game.simulation.network.nodes[node_uuid]
nics = node_obj.nics
for nic_uuid, nic_obj in nics.items():
self.ip_address_list.append(nic_obj.ip_address)
@@ -622,6 +655,7 @@ class ActionManager:
"num_folders": max_folders_per_node,
"num_files": max_files_per_folder,
"num_services": max_services_per_node,
"num_applications": max_applications_per_node,
"num_nics": max_nics_per_node,
"num_acl_rules": max_acl_rules,
"num_protocols": len(self.protocols),
@@ -734,7 +768,7 @@ class ActionManager:
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
node = self.game.simulation.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
return folder_uuids[folder_idx] if len(folder_uuids) > folder_idx else None
@@ -752,7 +786,7 @@ class ActionManager:
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
node = self.game.simulation.network.nodes[node_uuid]
folder_uuids = list(node.file_system.folders.keys())
if len(folder_uuids) <= folder_idx:
return None
@@ -771,10 +805,22 @@ class ActionManager:
:rtype: Optional[str]
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node = self.sim.network.nodes[node_uuid]
node = self.game.simulation.network.nodes[node_uuid]
service_uuids = list(node.services.keys())
return service_uuids[service_idx] if len(service_uuids) > service_idx else None
def get_application_uuid_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]:
"""Get the application UUID corresponding to the given node and service indices.
:param node_idx: The index of the node.
:type node_idx: int
:param application_idx: The index of the service on the node.
:type application_idx: int
:return: The UUID of the service. Or None if the node has fewer services than the given index.
:rtype: Optional[str]
"""
return self.application_uuids[node_idx][application_idx]
def get_internet_protocol_by_idx(self, protocol_idx: int) -> str:
"""Get the internet protocol corresponding to the given index.
@@ -819,7 +865,7 @@ class ActionManager:
:rtype: str
"""
node_uuid = self.get_node_uuid_by_idx(node_idx)
node_obj = self.sim.network.nodes[node_uuid]
node_obj = self.game.simulation.network.nodes[node_uuid]
nics = list(node_obj.nics.keys())
if len(nics) <= nic_idx:
return None

View File

@@ -0,0 +1,48 @@
import random
from typing import Dict, List, Tuple
from gymnasium.core import ObsType
from primaite.game.agent.interface import AbstractScriptedAgent
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
class DataManipulationAgent(AbstractScriptedAgent):
"""Agent that uses a DataManipulationBot to perform an SQL injection attack."""
data_manipulation_bots: List["DataManipulationBot"] = []
next_execution_timestep: int = 0
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._set_next_execution_timestep(self.agent_settings.start_settings.start_step)
def _set_next_execution_timestep(self, timestep: int) -> None:
"""Set the next execution timestep with a configured random variance.
:param timestep: The timestep to add variance to.
"""
random_timestep_increment = random.randint(
-self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance
)
self.next_execution_timestep = timestep + random_timestep_increment
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
:type obs: ObsType
:param reward: _description_, defaults to None
:type reward: float, optional
:return: _description_
:rtype: Tuple[str, Dict]
"""
current_timestep = self.action_manager.game.step_counter
if current_timestep < self.next_execution_timestep:
return "DONOTHING", {"dummy": 0}
self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency)
return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}

View File

@@ -1,13 +1,64 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel, model_validator
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
if TYPE_CHECKING:
pass
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
start_step: int = 5
"The timestep at which an agent begins performing it's actions"
frequency: int = 5
"The number of timesteps to wait between performing actions"
variance: int = 0
"The amount the frequency can randomly change to"
@model_validator(mode="after")
def check_variance_lt_frequency(self) -> "AgentStartSettings":
"""
Make sure variance is equal to or lower than frequency.
This is because the calculation for the next execution time is now + (frequency +- variance). If variance were
greater than frequency, sometimes the bracketed term would be negative and the attack would never happen again.
"""
if self.variance > self.frequency:
raise ValueError(
f"Agent start settings error: variance must be lower than frequency "
f"{self.variance=}, {self.frequency=}"
)
return self
class AgentSettings(BaseModel):
"""Settings for configuring the operation of an agent."""
start_settings: Optional[AgentStartSettings] = None
"Configuration for when an agent begins performing it's actions"
@classmethod
def from_config(cls, config: Optional[Dict]) -> "AgentSettings":
"""Construct agent settings from a config dictionary.
:param config: A dict of options for the agent settings.
:type config: Dict
:return: The agent settings.
:rtype: AgentSettings
"""
if config is None:
return cls()
return cls(**config)
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
@@ -18,6 +69,7 @@ class AbstractAgent(ABC):
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings] = None,
) -> None:
"""
Initialize an agent.
@@ -35,10 +87,7 @@ class AbstractAgent(ABC):
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
self.agent_settings = agent_settings or AgentSettings()
def update_observation(self, state: Dict) -> ObsType:
"""

View File

@@ -162,7 +162,7 @@ class ServiceObservation(AbstractObservation):
:return: Constructed service observation
:rtype: ServiceObservation
"""
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]].uuid])
return cls(where=parent_where + ["services", game.ref_map_services[config["service_ref"]]])
class LinkObservation(AbstractObservation):
@@ -263,8 +263,8 @@ class FolderObservation(AbstractObservation):
self.files.append(FileObservation())
while len(self.files) > num_files_per_folder:
truncated_file = self.files.pop()
msg = f"Too many files in folde observation. Truncating file {truncated_file}"
_LOGGER.warn(msg)
msg = f"Too many files in folder observation. Truncating file {truncated_file}"
_LOGGER.warning(msg)
self.default_observation = {
"health_status": 0,
@@ -438,7 +438,7 @@ class NodeObservation(AbstractObservation):
while len(self.services) > num_services_per_node:
truncated_service = self.services.pop()
msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}"
_LOGGER.warn(msg)
_LOGGER.warning(msg)
# truncate service list
self.folders: List[FolderObservation] = folders
@@ -448,7 +448,7 @@ class NodeObservation(AbstractObservation):
while len(self.folders) > num_folders_per_node:
truncated_folder = self.folders.pop()
msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}"
_LOGGER.warn(msg)
_LOGGER.warning(msg)
self.nics: List[NicObservation] = nics
while len(self.nics) < num_nics_per_node:
@@ -456,7 +456,7 @@ class NodeObservation(AbstractObservation):
while len(self.nics) > num_nics_per_node:
truncated_nic = self.nics.pop()
msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}"
_LOGGER.warn(msg)
_LOGGER.warning(msg)
self.logon_status: bool = logon_status

View File

@@ -210,16 +210,16 @@ class WebServer404Penalty(AbstractReward):
f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not "
"found in reward config."
)
_LOGGER.warn(msg)
_LOGGER.warning(msg)
return DummyReward() # TODO: should we error out with incorrect inputs? Probably!
node_uuid = game.ref_map_nodes[node_ref]
service_uuid = game.ref_map_services[service_ref].uuid
service_uuid = game.ref_map_services[service_ref]
if not (node_uuid and service_uuid):
msg = (
f"{cls.__name__} could not be initialised because node {node_ref} and service {service_ref} were not"
" found in the simulator."
)
_LOGGER.warn(msg)
_LOGGER.warning(msg)
return DummyReward() # TODO: consider erroring here as well
return cls(node_uuid=node_uuid, service_uuid=service_uuid)
@@ -238,7 +238,8 @@ class RewardFunction:
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float
self.current_reward: float = 0.0
self.total_reward: float = 0.0
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.

View File

@@ -1,5 +1,4 @@
"""PrimAITE game - Encapsulates the simulation and agents."""
from copy import deepcopy
from ipaddress import IPv4Address
from typing import Dict, List
@@ -7,10 +6,11 @@ from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.data_manipulation_bot import DataManipulationAgent
from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.base import NIC, NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
@@ -18,14 +18,14 @@ from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.sim_container import Simulation
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.services.service import Service
from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
@@ -57,9 +57,6 @@ class PrimaiteGame:
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
@@ -75,16 +72,16 @@ class PrimaiteGame:
self.options: PrimaiteGameOptions
"""Special options that apply for the entire game."""
self.ref_map_nodes: Dict[str, Node] = {}
self.ref_map_nodes: Dict[str, str] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
self.ref_map_services: Dict[str, str] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
self.ref_map_applications: Dict[str, str] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
self.ref_map_links: Dict[str, str] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
def step(self):
@@ -128,6 +125,7 @@ class PrimaiteGame:
for agent in self.agents:
agent.update_observation(state)
agent.update_reward(state)
agent.reward_function.total_reward += agent.reward_function.current_reward
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
@@ -157,7 +155,9 @@ class PrimaiteGame:
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
self.simulation.reset_component_for_episode(episode=self.episode_counter)
for agent in self.agents:
agent.reward_function.total_reward = 0.0
def close(self) -> None:
"""Close the game, this will close the simulation."""
@@ -187,10 +187,6 @@ class PrimaiteGame:
sim = game.simulation
net = sim.network
game.ref_map_nodes: Dict[str, Node] = {}
game.ref_map_services: Dict[str, Service] = {}
game.ref_map_links: Dict[str, Link] = {}
nodes_cfg = cfg["simulation"]["network"]["nodes"]
links_cfg = cfg["simulation"]["network"]["links"]
for node_cfg in nodes_cfg:
@@ -203,6 +199,7 @@ class PrimaiteGame:
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg["dns_server"],
operating_state=NodeOperatingState.ON,
)
elif n_type == "server":
new_node = Server(
@@ -211,16 +208,26 @@ class PrimaiteGame:
subnet_mask=node_cfg["subnet_mask"],
default_gateway=node_cfg["default_gateway"],
dns_server=node_cfg.get("dns_server"),
operating_state=NodeOperatingState.ON,
)
elif n_type == "switch":
new_node = Switch(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
new_node = Switch(
hostname=node_cfg["hostname"],
num_ports=node_cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
elif n_type == "router":
new_node = Router(hostname=node_cfg["hostname"], num_ports=node_cfg.get("num_ports"))
new_node = Router(
hostname=node_cfg["hostname"],
num_ports=node_cfg.get("num_ports"),
operating_state=NodeOperatingState.ON,
)
if "ports" in node_cfg:
for port_num, port_cfg in node_cfg["ports"].items():
new_node.configure_port(
port=port_num, ip_address=port_cfg["ip_address"], subnet_mask=port_cfg["subnet_mask"]
)
# new_node.enable_port(port_num)
if "acl" in node_cfg:
for r_num, r_cfg in node_cfg["acl"].items():
# excuse the uncommon walrus operator ` := `. It's just here as a shorthand, to avoid repeating
@@ -236,9 +243,10 @@ class PrimaiteGame:
position=r_num,
)
else:
print("invalid node type")
_LOGGER.warning(f"invalid node type {n_type} in config")
if "services" in node_cfg:
for service_cfg in node_cfg["services"]:
new_service = None
service_ref = service_cfg["ref"]
service_type = service_cfg["type"]
service_types_mapping = {
@@ -247,15 +255,16 @@ class PrimaiteGame:
"DatabaseClient": DatabaseClient,
"DatabaseService": DatabaseService,
"WebServer": WebServer,
"DataManipulationBot": DataManipulationBot,
"FTPClient": FTPClient,
"FTPServer": FTPServer,
}
if service_type in service_types_mapping:
print(f"installing {service_type} on node {new_node.hostname}")
_LOGGER.debug(f"installing {service_type} on node {new_node.hostname}")
new_node.software_manager.install(service_types_mapping[service_type])
new_service = new_node.software_manager.software[service_type]
game.ref_map_services[service_ref] = new_service
game.ref_map_services[service_ref] = new_service.uuid
else:
print(f"service type not found {service_type}")
_LOGGER.warning(f"service type not found {service_type}")
# service-dependent options
if service_type == "DatabaseClient":
if "options" in service_cfg:
@@ -268,30 +277,49 @@ class PrimaiteGame:
if "domain_mapping" in opt:
for domain, ip in opt["domain_mapping"].items():
new_service.dns_register(domain, ip)
if service_type == "DatabaseService":
if "options" in service_cfg:
opt = service_cfg["options"]
if "backup_server_ip" in opt:
new_service.configure_backup(backup_server=IPv4Address(opt["backup_server_ip"]))
new_service.start()
if "applications" in node_cfg:
for application_cfg in node_cfg["applications"]:
new_application = None
application_ref = application_cfg["ref"]
application_type = application_cfg["type"]
application_types_mapping = {
"WebBrowser": WebBrowser,
"DataManipulationBot": DataManipulationBot,
}
if application_type in application_types_mapping:
new_node.software_manager.install(application_types_mapping[application_type])
new_application = new_node.software_manager.software[application_type]
game.ref_map_applications[application_ref] = new_application
game.ref_map_applications[application_ref] = new_application.uuid
else:
print(f"application type not found {application_type}")
_LOGGER.warning(f"application type not found {application_type}")
if application_type == "DataManipulationBot":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.configure(
server_ip_address=IPv4Address(opt.get("server_ip")),
payload=opt.get("payload"),
port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")),
data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")),
)
elif application_type == "WebBrowser":
if "options" in application_cfg:
opt = application_cfg["options"]
new_application.target_url = opt.get("target_url")
if "nics" in node_cfg:
for nic_num, nic_cfg in node_cfg["nics"].items():
new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"]))
net.add_node(new_node)
new_node.power_on()
game.ref_map_nodes[
node_ref
] = (
new_node.uuid
) # TODO: fix inconsistency with service and link. Node gets added by uuid, but service by object
game.ref_map_nodes[node_ref] = new_node.uuid
# 2. create links between nodes
for link_cfg in links_cfg:
@@ -323,11 +351,25 @@ class PrimaiteGame:
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
action_space_cfg["options"]["application_uuids"] = []
# if a list of nodes is defined, convert them from node references to node UUIDs
for action_node_option in action_space_cfg.get("options", {}).pop("nodes", {}):
if "node_ref" in action_node_option:
node_uuid = game.ref_map_nodes[action_node_option["node_ref"]]
action_space_cfg["options"]["node_uuids"].append(node_uuid)
if "applications" in action_node_option:
node_application_uuids = []
for application_option in action_node_option["applications"]:
# TODO: fix inconsistency with node uuids and application uuids. The node object get added to
# node_uuid, whereas here the application gets added by uuid.
application_uuid = game.ref_map_applications[application_option["application_ref"]]
node_application_uuids.append(application_uuid)
action_space_cfg["options"]["application_uuids"].append(node_application_uuids)
else:
action_space_cfg["options"]["application_uuids"].append([])
# Each action space can potentially have a different list of nodes that it can apply to. Therefore,
# we will pass node_uuids as a part of the action space config.
# However, it's not possible to specify the node uuids directly in the config, as they are generated
@@ -345,6 +387,8 @@ class PrimaiteGame:
# CREATE REWARD FUNCTION
rew_function = RewardFunction.from_config(reward_function_cfg, game=game)
agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings"))
# CREATE AGENT
if agent_type == "GreenWebBrowsingAgent":
# TODO: implement non-random agents and fix this parsing
@@ -353,6 +397,7 @@ class PrimaiteGame:
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
elif agent_type == "ProxyAgent":
@@ -365,16 +410,17 @@ class PrimaiteGame:
game.agents.append(new_agent)
game.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
agent_settings=agent_settings,
)
game.agents.append(new_agent)
else:
print("agent type not found")
_LOGGER.warning(f"agent type {agent_type} not found")
game._simulation_initial_state = deepcopy(game.simulation) # noqa
game.simulation.set_original_state()
return game

View File

@@ -0,0 +1,16 @@
from random import random
def simulate_trial(p_of_success: float) -> bool:
"""
Simulates the outcome of a single trial in a Bernoulli process.
This function returns True with a probability 'p_of_success', simulating a success outcome in a single
trial of a Bernoulli process. When this function is executed multiple times, the set of outcomes follows
a binomial distribution. This is useful in scenarios where one needs to model or simulate events that
have two possible outcomes (success or failure) with a fixed probability of success.
:param p_of_success: The probability of success in a single trial, ranging from 0 to 1.
:returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False.
"""
return random() < p_of_success

View File

@@ -1,5 +1,21 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a Multi agent system using RLLIB\n",
"\n",
"This notebook will demonstrate how to use the `PrimaiteRayMARLEnv` to train a very basic system with two PPO agents."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### First, Import packages and read our config file."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -8,75 +24,56 @@
"source": [
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# gym = PrimaiteRayEnv({\"game\":game})"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
"with open(PRIMAITE_PATHS.user_config_path / 'example_config/example_config_2_rl_agents.yaml', 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"env_config = {\"game\":game}\n",
"ray.init(local_mode=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create a Ray algorithm config which accepts our two agents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"game\":game})\n",
" .rollouts(num_rollout_workers=0)\n",
" .multi_agent(\n",
" policies={agent.agent_name for agent in game.rl_agents},\n",
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training\n",
"This example will save outputs to a default Ray directory and use mostly default settings."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -86,21 +83,11 @@
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"training_iteration\": 128},\n",
" checkpoint_config=air.CheckpointConfig(\n",
" checkpoint_frequency=10,\n",
" ),\n",
" stop={\"timesteps_total\": 512},\n",
" ),\n",
" param_space=config\n",
").fit()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {

View File

@@ -1,5 +1,13 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a Single agent system using RLLib\n",
"This notebook will demonstrate how to use PrimaiteRayEnv to train a basic PPO agent."
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -10,19 +18,25 @@
"import yaml\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteRayEnv\n",
"from ray.rllib.algorithms import ppo\n",
"from ray import air, tune\n",
"import ray\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
"with open(example_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)\n",
"\n",
"game = PrimaiteGame.from_config(cfg)"
"ray.init(local_mode=True)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Create a Ray algorithm and pass it our config."
]
},
{
@@ -31,7 +45,21 @@
"metadata": {},
"outputs": [],
"source": [
"gym = PrimaiteRayEnv({\"game\":game})"
"env_config = {\"cfg\":cfg}\n",
"\n",
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .training(train_batch_size=128)\n",
")\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Set training parameters and start the training"
]
},
{
@@ -40,61 +68,13 @@
"metadata": {},
"outputs": [],
"source": [
"import ray\n",
"from ray.rllib.algorithms import ppo"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ray.shutdown()\n",
"ray.init()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_config = {\"game\":game}\n",
"config = {\n",
" \"env\" : PrimaiteRayEnv,\n",
" \"env_config\" : env_config,\n",
" \"disable_env_checking\": True,\n",
" \"num_rollout_workers\": 0,\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo = ppo.PPO(config=config)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(5):\n",
" result = algo.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"algo.save(\"temp/deleteme\")"
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
]
}
],

View File

@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/cade/repos/PrimAITE/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"2023-11-26 23:25:47,985\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-26 23:25:51,213\tINFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
"2023-11-26 23:25:51,491\tWARNING __init__.py:10 -- PG has/have been moved to `rllib_contrib` and will no longer be maintained by the RLlib team. You can still use it/them normally inside RLlib util Ray 2.8, but from Ray 2.9 on, all `rllib_contrib` algorithms will no longer be part of the core repo, and will therefore have to be installed separately with pinned dependencies for e.g. ray[rllib] and other packages! See https://github.com/ray-project/ray/tree/master/rllib_contrib#rllib-contrib for more information on the RLlib contrib effort.\n"
]
}
],
"source": [
"from primaite.session.session import PrimaiteSession\n",
"from primaite.game.game import PrimaiteGame\n",
"from primaite.config.load import example_config_path\n",
"\n",
"from primaite.simulator.system.services.database.database_service import DatabaseService\n",
"\n",
"import yaml"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-11-26 23:25:51,579::ERROR::primaite.simulator.network.hardware.base::175::NIC a9:92:0a:5e:1b:e4/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,580::ERROR::primaite.simulator.network.hardware.base::175::NIC ef:03:23:af:3c:19/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,581::ERROR::primaite.simulator.network.hardware.base::175::NIC ae:cf:83:2f:94:17/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,582::ERROR::primaite.simulator.network.hardware.base::175::NIC 4c:b2:99:e2:4a:5d/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,583::ERROR::primaite.simulator.network.hardware.base::175::NIC b9:eb:f9:c2:17:2f/127.0.0.1 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,590::ERROR::primaite.simulator.network.hardware.base::175::NIC cb:df:ca:54:be:01/192.168.1.10 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,595::ERROR::primaite.simulator.network.hardware.base::175::NIC 6e:32:12:da:4d:0d/192.168.1.12 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,600::ERROR::primaite.simulator.network.hardware.base::175::NIC 58:6e:9b:a7:68:49/192.168.1.14 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,604::ERROR::primaite.simulator.network.hardware.base::175::NIC 33:db:a6:40:dd:a3/192.168.1.16 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,608::ERROR::primaite.simulator.network.hardware.base::175::NIC 72:aa:2b:c0:4c:5f/192.168.1.110 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,610::ERROR::primaite.simulator.network.hardware.base::175::NIC 11:d7:0e:90:d9:a4/192.168.10.110 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,614::ERROR::primaite.simulator.network.hardware.base::175::NIC 86:2b:a4:e5:4d:0f/192.168.10.21 cannot be enabled as it is not connected to a Link\n",
"2023-11-26 23:25:51,631::ERROR::primaite.simulator.network.hardware.base::175::NIC af:ad:8f:84:f1:db/192.168.10.22 cannot be enabled as it is not connected to a Link\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"installing DNSServer on node domain_controller\n",
"installing DatabaseClient on node web_server\n",
"installing WebServer on node web_server\n",
"installing DatabaseService on node database_server\n",
"installing FTPClient on node database_server\n",
"installing FTPServer on node backup_server\n",
"installing DNSClient on node client_1\n",
"installing DNSClient on node client_2\n"
]
}
],
"source": [
"\n",
"with open(example_config_path(),'r') as cfgfile:\n",
" cfg = yaml.safe_load(cfgfile)\n",
"game = PrimaiteGame.from_config(cfg)\n",
"net = game.simulation.network\n",
"database_server = net.get_node_by_hostname('database_server')\n",
"web_server = net.get_node_by_hostname('web_server')\n",
"client_1 = net.get_node_by_hostname('client_1')\n",
"\n",
"db_service = database_server.software_manager.software[\"DatabaseService\"]\n",
"db_client = web_server.software_manager.software[\"DatabaseClient\"]\n",
"# db_client.run()\n",
"db_manipulation_bot = client_1.software_manager.software[\"DataManipulationBot\"]\n",
"db_manipulation_bot.port_scan_p_of_success=1.0\n",
"db_manipulation_bot.data_manipulation_p_of_success=1.0\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"db_client.run()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_service.backup_database()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"db_manipulation_bot.run()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"False"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_service.restore_backup()"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"db_client.query(\"SELECT\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"db_manipulation_bot.run()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"client_1.ping(database_server.ethernet_port[1].ip_address)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from pydantic import validate_call, BaseModel"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"class A(BaseModel):\n",
" x:int\n",
"\n",
" @validate_call\n",
" def increase_x(self, by:int) -> None:\n",
" self.x += 1"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"my_a = A(x=3)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"ename": "ValidationError",
"evalue": "1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
"\u001b[1;32m/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb Cell 15\u001b[0m line \u001b[0;36m1\n\u001b[0;32m----> <a href='vscode-notebook-cell://wsl%2Bubuntu/home/cade/repos/PrimAITE/src/primaite/notebooks/uc2_demo.ipynb#X23sdnNjb2RlLXJlbW90ZQ%3D%3D?line=0'>1</a>\u001b[0m my_a\u001b[39m.\u001b[39;49mincrease_x(\u001b[39m3.2\u001b[39;49m)\n",
"File \u001b[0;32m~/repos/PrimAITE/venv/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py:91\u001b[0m, in \u001b[0;36mValidateCallWrapper.__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 90\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39m__call__\u001b[39m(\u001b[39mself\u001b[39m, \u001b[39m*\u001b[39margs: Any, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwargs: Any) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m Any:\n\u001b[0;32m---> 91\u001b[0m res \u001b[39m=\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m__pydantic_validator__\u001b[39m.\u001b[39;49mvalidate_python(pydantic_core\u001b[39m.\u001b[39;49mArgsKwargs(args, kwargs))\n\u001b[1;32m 92\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__:\n\u001b[1;32m 93\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39m__return_pydantic_validator__\u001b[39m.\u001b[39mvalidate_python(res)\n",
"\u001b[0;31mValidationError\u001b[0m: 1 validation error for increase_x\n0\n Input should be a valid integer, got a number with a fractional part [type=int_from_float, input_value=3.2, input_type=float]\n For further information visit https://errors.pydantic.dev/2.1/v/int_from_float"
]
}
],
"source": [
"my_a.increase_x(3.2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -37,11 +37,14 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {}
return next_obs, reward, terminated, truncated, info
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
print(
f"Resetting environment, episode {self.game.episode_counter}, "
f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}"
)
self.game.reset()
state = self.game.get_sim_state()
self.game.update_agents(state)
@@ -69,14 +72,15 @@ class PrimaiteGymEnv(gymnasium.Env):
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict[str, PrimaiteGame]) -> None:
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
"""
self.env = PrimaiteGymEnv(game=env_config["game"])
self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"]))
self.env.game.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -92,14 +96,14 @@ class PrimaiteRayEnv(gymnasium.Env):
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
def __init__(self, env_config: Optional[Dict] = None) -> None:
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict[str, PrimaiteGame]
"""
self.game: PrimaiteGame = env_config["game"]
self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"])
"""Reference to the primaite game"""
self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents}
"""List of all possible agents in the environment. This list should not change!"""
@@ -108,7 +112,10 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{name: agent.observation_manager.space for name, agent in self.agents.items()}
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
@@ -159,4 +166,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
return {name: agent.observation_manager.current_observation for name, agent in self.agents.items()}
obs = {}
for name, agent in self.agents.items():
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs

View File

@@ -12,6 +12,10 @@ from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
from primaite import getLogger
_LOGGER = getLogger(__name__)
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
@@ -19,7 +23,7 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
config = {
self.config = {
"env": PrimaiteRayEnv,
"env_config": {"game": session.game},
"disable_env_checking": True,
@@ -29,12 +33,13 @@ class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
ray.shutdown()
ray.init()
self._algo = ppo.PPO(config=config)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
for ep in range(n_episodes):
self._algo.train()
self.config["training_iterations"] = n_episodes * timesteps_per_episode
self.config["train_batch_size"] = 128
self._algo = ppo.PPO(config=self.config)
_LOGGER.info("Starting RLLIB training session")
self._algo.train()
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""

View File

@@ -51,14 +51,13 @@ class SB3Policy(PolicyABC, identifier="SB3"):
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
reward_data = evaluate_policy(
_ = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
print(reward_data)
def save(self, save_path: Path) -> None:
"""

View File

@@ -62,6 +62,7 @@ class PrimaiteSession:
def start_session(self) -> None:
"""Commence the training/eval session."""
print("Starting Primaite Session")
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
@@ -88,16 +89,16 @@ class PrimaiteSession:
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
io_manager = SessionIO(SessionIOSettings(**io_settings))
game = PrimaiteGame.from_config(cfg)
sess = cls(game=game)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config={"game": game})

View File

@@ -113,7 +113,7 @@ class RequestManager(BaseModel):
"""
if name in self.request_types:
msg = f"Overwriting request type {name}."
_LOGGER.warning(msg)
_LOGGER.debug(msg)
self.request_types[name] = request_type
@@ -153,6 +153,8 @@ class SimComponent(BaseModel):
uuid: str
"""The component UUID."""
_original_state: Dict = {}
def __init__(self, **kwargs):
if not kwargs.get("uuid"):
kwargs["uuid"] = str(uuid4())
@@ -160,6 +162,16 @@ class SimComponent(BaseModel):
self._request_manager: RequestManager = self._init_request_manager()
self._parent: Optional["SimComponent"] = None
# @abstractmethod
def set_original_state(self):
"""Sets the original state."""
pass
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for key, value in self._original_state.items():
self.__setattr__(key, value)
def _init_request_manager(self) -> RequestManager:
"""
Initialise the request manager for this component.
@@ -227,14 +239,6 @@ class SimComponent(BaseModel):
"""
pass
def reset_component_for_episode(self, episode: int):
"""
Reset this component to its original state for a new episode.
Override this method with anything that needs to happen within the component for it to be reset.
"""
pass
@property
def parent(self) -> "SimComponent":
"""Reference to the parent object which manages this object.

View File

@@ -42,6 +42,19 @@ class Account(SimComponent):
"Account Type, currently this can be service account (used by apps) or user account."
enabled: bool = True
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"num_logons",
"num_logoffs",
"num_group_changes",
"username",
"password",
"account_type",
"enabled",
}
self._original_state = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -59,7 +72,7 @@ class Account(SimComponent):
"num_group_changes": self.num_group_changes,
"username": self.username,
"password": self.password,
"account_type": self.account_type.name,
"account_type": self.account_type.value,
"enabled": self.enabled,
}
)

View File

@@ -73,6 +73,20 @@ class File(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})")
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}")
super().set_original_state()
vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}")
super().reset_component_for_episode(episode)
@property
def path(self) -> str:
"""

View File

@@ -35,6 +35,45 @@ class FileSystem(SimComponent):
if not self.folders:
self.create_folder("root")
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}")
for folder in self.folders.values():
folder.set_original_state()
# Capture a list of all 'original' file uuids
original_keys = list(self.folders.keys())
vals_to_include = {"sim_root"}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_folder_uuids"] = original_keys
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}")
# Move any 'original' folder that have been deleted back to folders
original_folder_uuids = self._original_state["original_folder_uuids"]
for uuid in original_folder_uuids:
if uuid in self.deleted_folders:
folder = self.deleted_folders[uuid]
self.deleted_folders.pop(uuid)
self.folders[uuid] = folder
self._folders_by_name[folder.name] = folder
# Clear any other deleted folders that aren't original (have been created by agent)
self.deleted_folders.clear()
# Now clear all non-original folders created by agent
current_folder_uuids = list(self.folders.keys())
for uuid in current_folder_uuids:
if uuid not in original_folder_uuids:
folder = self.folders[uuid]
self.folders.pop(uuid)
self._folders_by_name.pop(folder.name)
# Now reset all remaining folders
for folder in self.folders.values():
folder.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()

View File

@@ -85,6 +85,11 @@ class FileSystemItemABC(SimComponent):
deleted: bool = False
"If true, the FileSystemItem was deleted."
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red"}
self._original_state = self.model_dump(include=vals_to_keep)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.

View File

@@ -51,6 +51,51 @@ class Folder(FileSystemItemABC):
self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})")
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}")
for file in self.files.values():
file.set_original_state()
super().set_original_state()
vals_to_include = {
"scan_duration",
"scan_countdown",
"red_scan_duration",
"red_scan_countdown",
"restore_duration",
"restore_countdown",
}
self._original_state.update(self.model_dump(include=vals_to_include))
self._original_state["original_file_uuids"] = list(self.files.keys())
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}")
# Move any 'original' file that have been deleted back to files
original_file_uuids = self._original_state["original_file_uuids"]
for uuid in original_file_uuids:
if uuid in self.deleted_files:
file = self.deleted_files[uuid]
self.deleted_files.pop(uuid)
self.files[uuid] = file
self._files_by_name[file.name] = file
# Clear any other deleted files that aren't original (have been created by agent)
self.deleted_files.clear()
# Now clear all non-original files created by agent
current_file_uuids = list(self.files.keys())
for uuid in current_file_uuids:
if uuid not in original_file_uuids:
file = self.files[uuid]
self.files.pop(uuid)
self._files_by_name.pop(file.name)
# Now reset all remaining files
for file in self.files.values():
file.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(

View File

@@ -12,6 +12,8 @@ from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
@@ -43,6 +45,32 @@ class Network(SimComponent):
self._nx_graph = MultiGraph()
def set_original_state(self):
"""Sets the original state."""
for node in self.nodes.values():
node.set_original_state()
for link in self.links.values():
link.set_original_state()
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
for node in self.nodes.values():
node.reset_component_for_episode(episode)
for link in self.links.values():
link.reset_component_for_episode(episode)
for node in self.nodes.values():
node.power_on()
for nic in node.nics.values():
nic.enable()
# Reset software
for software in node.software_manager.software.values():
if isinstance(software, Service):
software.start()
elif isinstance(software, Application):
software.run()
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
self._node_request_manager = RequestManager()
@@ -52,6 +80,17 @@ class Network(SimComponent):
)
return rm
def apply_timestep(self, timestep: int) -> None:
"""Apply a timestep evolution to this the network and its nodes and links."""
super().apply_timestep(timestep=timestep)
# apply timestep to nodes
for node_id in self.nodes:
self.nodes[node_id].apply_timestep(timestep=timestep)
# apply timestep to links
for link_id in self.links:
self.links[link_id].apply_timestep(timestep=timestep)
@property
def routers(self) -> List[Router]:
"""The Routers in the Network."""
@@ -181,7 +220,7 @@ class Network(SimComponent):
self._node_id_map[len(self.nodes)] = node
node.parent = self
self._nx_graph.add_node(node.hostname)
_LOGGER.info(f"Added node {node.uuid} to Network {self.uuid}")
_LOGGER.debug(f"Added node {node.uuid} to Network {self.uuid}")
self._node_request_manager.add_request(name=node.uuid, request_type=RequestType(func=node._request_manager))
def get_node_by_hostname(self, hostname: str) -> Optional[Node]:

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import re
import secrets
from enum import Enum
from ipaddress import IPv4Address, IPv4Network
from pathlib import Path
from typing import Any, Dict, Literal, Optional, Tuple, Union
@@ -15,6 +14,7 @@ from primaite.simulator import SIM_OUTPUT
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.domain.account import Account
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.protocols.arp import ARPEntry, ARPPacket
from primaite.simulator.network.transmission.data_link_layer import EthernetHeader, Frame
from primaite.simulator.network.transmission.network_layer import ICMPPacket, ICMPType, IPPacket, IPProtocol
@@ -121,6 +121,21 @@ class NIC(SimComponent):
_LOGGER.error(msg)
raise ValueError(msg)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
if episode and self.pcap:
self.pcap.current_episode = episode
self.pcap.setup_logger()
self.enable()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -166,13 +181,13 @@ class NIC(SimComponent):
if self.enabled:
return
if not self._connected_node:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Node")
_LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Node")
return
if self._connected_node.operating_state != NodeOperatingState.ON:
self._connected_node.sys_log.error(f"NIC {self} cannot be enabled as the endpoint is not turned on")
return
if not self._connected_link:
_LOGGER.error(f"NIC {self} cannot be enabled as it is not connected to a Link")
_LOGGER.debug(f"NIC {self} cannot be enabled as it is not connected to a Link")
return
self.enabled = True
@@ -308,6 +323,14 @@ class SwitchPort(SimComponent):
kwargs["mac_address"] = generate_mac_address()
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -454,6 +477,14 @@ class Link(SimComponent):
self.endpoint_b.connect_link(self)
self.endpoint_up()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"bandwidth", "current_load"}
self._original_state = self.model_dump(include=vals_to_include)
super().set_original_state()
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -536,15 +567,6 @@ class Link(SimComponent):
return True
return False
def reset_component_for_episode(self, episode: int):
"""
Link reset function.
Reset:
- returns the link current_load to 0.
"""
self.current_load = 0
def __str__(self) -> str:
return f"{self.endpoint_a}<-->{self.endpoint_b}"
@@ -584,6 +606,10 @@ class ARPCache:
)
print(table)
def clear(self):
"""Clears the arp cache."""
self.arp.clear()
def add_arp_cache_entry(self, ip_address: IPv4Address, mac_address: str, nic: NIC, override: bool = False):
"""
Add an ARP entry to the cache.
@@ -756,6 +782,10 @@ class ICMP:
self.arp: ARPCache = arp_cache
self.request_replies = {}
def clear(self):
"""Clears the ICMP request replies tracker."""
self.request_replies.clear()
def process_icmp(self, frame: Frame, from_nic: NIC, is_reattempt: bool = False):
"""
Process an ICMP packet, including handling echo requests and replies.
@@ -856,19 +886,6 @@ class ICMP:
return sequence, icmp_packet.identifier
class NodeOperatingState(Enum):
"""Enumeration of Node Operating States."""
ON = 1
"The node is powered on."
OFF = 2
"The node is powered off."
BOOTING = 3
"The node is in the process of booting up."
SHUTTING_DOWN = 4
"The node is in the process of shutting down."
class Node(SimComponent):
"""
A basic Node class that represents a node on the network.
@@ -972,6 +989,60 @@ class Node(SimComponent):
self.arp.nics = self.nics
self.session_manager.software_manager = self.software_manager
self._install_system_software()
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
for software in self.software_manager.software.values():
software.set_original_state()
self.file_system.set_original_state()
for nic in self.nics.values():
nic.set_original_state()
vals_to_include = {
"hostname",
"default_gateway",
"operating_state",
"revealed_to_red",
"start_up_duration",
"start_up_countdown",
"shut_down_duration",
"shut_down_countdown",
"is_resetting",
"node_scan_duration",
"node_scan_countdown",
"red_scan_countdown",
}
self._original_state = self.model_dump(include=vals_to_include)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
super().reset_component_for_episode(episode)
# Reset ARP Cache
self.arp.clear()
# Reset ICMP
self.icmp.clear()
# Reset Session Manager
self.session_manager.clear()
# Reset File System
self.file_system.reset_component_for_episode(episode)
# Reset all Nics
for nic in self.nics.values():
nic.reset_component_for_episode(episode)
for software in self.software_manager.software.values():
software.reset_component_for_episode(episode)
if episode and self.sys_log:
self.sys_log.current_episode = episode
self.sys_log.setup_logger()
def _init_request_manager(self) -> RequestManager:
# TODO: I see that this code is really confusing and hard to read right now... I think some of these things will
@@ -1090,18 +1161,21 @@ class Node(SimComponent):
else:
if self.operating_state == NodeOperatingState.BOOTING:
self.operating_state = NodeOperatingState.ON
self.sys_log.info("Turned on")
self.sys_log.info(f"{self.hostname}: Turned on")
for nic in self.nics.values():
if nic._connected_link:
nic.enable()
self._start_up_actions()
# count down to shut down
if self.shut_down_countdown > 0:
self.shut_down_countdown -= 1
else:
if self.operating_state == NodeOperatingState.SHUTTING_DOWN:
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Turned off")
self.sys_log.info(f"{self.hostname}: Turned off")
self._shut_down_actions()
# if resetting turn back on
if self.is_resetting:
@@ -1197,6 +1271,7 @@ class Node(SimComponent):
self.start_up_countdown = self.start_up_duration
if self.start_up_duration <= 0:
self._start_up_actions()
self.operating_state = NodeOperatingState.ON
self.sys_log.info("Turned on")
for nic in self.nics.values():
@@ -1212,6 +1287,7 @@ class Node(SimComponent):
self.shut_down_countdown = self.shut_down_duration
if self.shut_down_duration <= 0:
self._shut_down_actions()
self.operating_state = NodeOperatingState.OFF
self.sys_log.info("Turned off")
@@ -1418,103 +1494,35 @@ class Node(SimComponent):
_LOGGER.info(f"Removed application {application.uuid} from node {self.uuid}")
self._application_request_manager.remove_request(application.uuid)
def _shut_down_actions(self):
"""Actions to perform when the node is shut down."""
# Turn off all the services in the node
for service_id in self.services:
self.services[service_id].stop()
# Turn off all the applications in the node
for app_id in self.applications:
self.applications[app_id].close()
# Turn off all processes in the node
# for process_id in self.processes:
# self.processes[process_id]
def _start_up_actions(self):
"""Actions to perform when the node is starting up."""
# Turn on all the services in the node
for service_id in self.services:
self.services[service_id].start()
# Turn on all the applications in the node
for app_id in self.applications:
self.applications[app_id].run()
# Turn off all processes in the node
# for process_id in self.processes:
# self.processes[process_id]
def __contains__(self, item: Any) -> bool:
if isinstance(item, Service):
return item.uuid in self.services
return None
class Switch(Node):
"""A class representing a Layer 2 network switch."""
num_ports: int = 24
"The number of ports on the switch."
switch_ports: Dict[int, SwitchPort] = {}
"The SwitchPorts on the switch."
mac_address_table: Dict[str, SwitchPort] = {}
"A MAC address table mapping destination MAC addresses to corresponding SwitchPorts."
def __init__(self, **kwargs):
super().__init__(**kwargs)
if not self.switch_ports:
self.switch_ports = {i: SwitchPort() for i in range(1, self.num_ports + 1)}
for port_num, port in self.switch_ports.items():
port._connected_node = self
port.parent = self
port.port_num = port_num
def show(self):
"""Prints a table of the SwitchPorts on the Switch."""
table = PrettyTable(["Port", "MAC Address", "Speed", "Status"])
for port_num, port in self.switch_ports.items():
table.add_row([port_num, port.mac_address, port.speed, "Enabled" if port.enabled else "Disabled"])
print(table)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation.
:return: Current state of this object and child objects.
:rtype: Dict
"""
return {
"uuid": self.uuid,
"num_ports": self.num_ports, # redundant?
"ports": {port_num: port.describe_state() for port_num, port in self.switch_ports.items()},
"mac_address_table": {mac: port for mac, port in self.mac_address_table.items()},
}
def _add_mac_table_entry(self, mac_address: str, switch_port: SwitchPort):
mac_table_port = self.mac_address_table.get(mac_address)
if not mac_table_port:
self.mac_address_table[mac_address] = switch_port
self.sys_log.info(f"Added MAC table entry: Port {switch_port.port_num} -> {mac_address}")
else:
if mac_table_port != switch_port:
self.mac_address_table.pop(mac_address)
self.sys_log.info(f"Removed MAC table entry: Port {mac_table_port.port_num} -> {mac_address}")
self._add_mac_table_entry(mac_address, switch_port)
def forward_frame(self, frame: Frame, incoming_port: SwitchPort):
"""
Forward a frame to the appropriate port based on the destination MAC address.
:param frame: The Frame to be forwarded.
:param incoming_port: The port number from which the frame was received.
"""
src_mac = frame.ethernet.src_mac_addr
dst_mac = frame.ethernet.dst_mac_addr
self._add_mac_table_entry(src_mac, incoming_port)
outgoing_port = self.mac_address_table.get(dst_mac)
if outgoing_port or dst_mac != "ff:ff:ff:ff:ff:ff":
outgoing_port.send_frame(frame)
else:
# If the destination MAC is not in the table, flood to all ports except incoming
for port in self.switch_ports.values():
if port != incoming_port:
port.send_frame(frame)
def disconnect_link_from_port(self, link: Link, port_number: int):
"""
Disconnect a given link from the specified port number on the switch.
:param link: The Link object to be disconnected.
:param port_number: The port number on the switch from where the link should be disconnected.
:raise NetworkError: When an invalid port number is provided or the link does not match the connection.
"""
port = self.switch_ports.get(port_number)
if port is None:
msg = f"Invalid port number {port_number} on the switch"
_LOGGER.error(msg)
raise NetworkError(msg)
if port._connected_link != link:
msg = f"The link does not match the connection at port number {port_number}"
_LOGGER.error(msg)
raise NetworkError(msg)
port.disconnect_link()

View File

@@ -0,0 +1,14 @@
from enum import Enum
class NodeOperatingState(Enum):
"""Enumeration of Node Operating States."""
ON = 1
"The node is powered on."
OFF = 2
"The node is powered off."
BOOTING = 3
"The node is in the process of booting up."
SHUTTING_DOWN = 4
"The node is in the process of shutting down."

View File

@@ -52,6 +52,11 @@ class ACLRule(SimComponent):
rule_strings.append(f"{key}={value}")
return ", ".join(rule_strings)
def set_original_state(self):
"""Sets the original state."""
vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def describe_state(self) -> Dict:
"""
Describes the current state of the ACLRule.
@@ -93,6 +98,18 @@ class AccessControlList(SimComponent):
super().__init__(**kwargs)
self._acl = [None] * (self.max_acl_rules - 1)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.implicit_rule.set_original_state()
vals_to_keep = {"implicit_action", "max_acl_rules", "acl"}
self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True)
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.implicit_rule.reset_component_for_episode(episode)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
@@ -337,6 +354,11 @@ class RouteEntry(SimComponent):
kwargs[key] = IPv4Address(kwargs[key])
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"}
self._original_values = self.model_dump(include=vals_to_include)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteEntry.
@@ -368,6 +390,18 @@ class RouteTable(SimComponent):
routes: List[RouteEntry] = []
sys_log: SysLog
def set_original_state(self):
"""Sets the original state."""
"""Sets the original state."""
super().set_original_state()
self._original_state["routes_orig"] = self.routes
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.routes.clear()
self.routes = self._original_state["routes_orig"]
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the RouteTable.
@@ -638,6 +672,27 @@ class Router(Node):
self.arp.nics = self.nics
self.icmp.arp = self.arp
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
self.acl.set_original_state()
self.route_table.set_original_state()
super().set_original_state()
vals_to_include = {"num_ports"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.arp.clear()
self.acl.reset_component_for_episode(episode)
self.route_table.reset_component_for_episode(episode)
for i, nic in self.ethernet_ports.items():
nic.reset_component_for_episode(episode)
self.enable_port(i)
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("acl", RequestType(func=self.acl._request_manager))
@@ -730,6 +785,7 @@ class Router(Node):
dst_ip_address=dst_ip_address,
dst_port=dst_port,
)
if not permitted:
at_port = self._get_port_of_nic(from_nic)
self.sys_log.info(f"Frame blocked at port {at_port} by rule {rule}")
@@ -763,6 +819,7 @@ class Router(Node):
nic.ip_address = ip_address
nic.subnet_mask = subnet_mask
self.sys_log.info(f"Configured port {port} with ip_address={ip_address}/{nic.ip_network.prefixlen}")
self.set_original_state()
def enable_port(self, port: int):
"""

View File

@@ -51,14 +51,22 @@ def client_server_routed() -> Network:
# Client 1
client_1 = Computer(
hostname="client_1", ip_address="192.168.2.2", subnet_mask="255.255.255.0", default_gateway="192.168.2.1"
hostname="client_1",
ip_address="192.168.2.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.2.1",
operating_state=NodeOperatingState.ON,
)
client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
# Server 1
server_1 = Server(
hostname="server_1", ip_address="192.168.1.2", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
hostname="server_1",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
server_1.power_on()
network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
@@ -139,8 +147,13 @@ def arcd_uc2_network() -> Network:
client_1.power_on()
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
db_manipulation_bot.configure(
server_ip_address=IPv4Address("192.168.1.14"),
payload="DELETE",
port_scan_p_of_success=1.0,
data_manipulation_p_of_success=1.0,
)
# Client 2
client_2 = Computer(
@@ -152,6 +165,8 @@ def arcd_uc2_network() -> Network:
operating_state=NodeOperatingState.ON,
)
client_2.power_on()
web_browser = client_2.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
# Domain Controller
@@ -234,7 +249,7 @@ def arcd_uc2_network() -> Network:
# noqa
]
database_server.software_manager.install(DatabaseService)
database_service: DatabaseService = database_server.software_manager.software["DatabaseService"] # noqa
database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa
database_service.start()
database_service.configure_backup(backup_server=IPv4Address("192.168.1.16"))
database_service._process_sql(ddl, None) # noqa
@@ -253,7 +268,7 @@ def arcd_uc2_network() -> Network:
web_server.power_on()
web_server.software_manager.install(DatabaseClient)
database_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
database_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.1.14"))
network.connect(endpoint_b=web_server.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
database_client.run()
@@ -262,7 +277,7 @@ def arcd_uc2_network() -> Network:
web_server.software_manager.install(WebServer)
# register the web_server to a domain
dns_server_service: DNSServer = domain_controller.software_manager.software["DNSServer"] # noqa
dns_server_service: DNSServer = domain_controller.software_manager.software.get("DNSServer") # noqa
dns_server_service.dns_register("arcd.com", web_server.ip_address)
# Backup Server

View File

@@ -35,6 +35,9 @@ class FTPCommand(Enum):
class FTPStatusCode(Enum):
"""Status code of the current FTP request."""
NOT_FOUND = 14
"""Destination not found."""
OK = 200
"""Command successful."""

View File

@@ -1,4 +1,4 @@
from enum import Enum
from enum import Enum, IntEnum
from primaite.simulator.network.protocols.packet import DataPacket
@@ -25,7 +25,7 @@ class HttpRequestMethod(Enum):
"""Apply partial modifications to a resource."""
class HttpStatusCode(Enum):
class HttpStatusCode(IntEnum):
"""List of available HTTP Statuses."""
OK = 200

View File

@@ -5,6 +5,8 @@ def convert_bytes_to_megabits(B: Union[int, float]) -> float: # noqa - Keep it
"""
Convert Bytes (file size) to Megabits (data transfer).
Technically Mebibits - but for simplicity sake, we'll call it megabit
:param B: The file size in Bytes.
:return: File bits to transfer in Megabits.
"""

View File

@@ -9,7 +9,7 @@ class Simulation(SimComponent):
"""Top-level simulation object which holds a reference to all other parts of the simulation."""
network: Network
domain: DomainController
# domain: DomainController
def __init__(self, **kwargs):
"""Initialise the Simulation."""
@@ -21,6 +21,14 @@ class Simulation(SimComponent):
super().__init__(**kwargs)
def set_original_state(self):
"""Sets the original state."""
self.network.set_original_state()
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.network.reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
# pass through network requests to the network objects

View File

@@ -2,8 +2,11 @@ from abc import abstractmethod
from enum import Enum
from typing import Any, Dict, Set
from primaite import getLogger
from primaite.simulator.system.software import IOSoftware, SoftwareHealthState
_LOGGER = getLogger(__name__)
class ApplicationOperatingState(Enum):
"""Enumeration of Application Operating States."""
@@ -38,6 +41,12 @@ class Application(IOSoftware):
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""
@@ -51,7 +60,7 @@ class Application(IOSoftware):
state = super().describe_state()
state.update(
{
"opearting_state": self.operating_state.value,
"operating_state": self.operating_state.value,
"execution_control_status": self.execution_control_status,
"num_executions": self.num_executions,
"groups": list(self.groups),
@@ -59,12 +68,38 @@ class Application(IOSoftware):
)
return state
def _can_perform_action(self) -> bool:
"""
Checks if the application can perform actions.
This is done by checking if the application is operating properly or the node it is installed
in is operational.
Returns true if the software can perform actions.
"""
if not super()._can_perform_action():
return False
if self.operating_state is not self.operating_state.RUNNING:
# service is not running
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
return False
return True
def run(self) -> None:
"""Open the Application."""
if not super()._can_perform_action():
return
if self.operating_state == ApplicationOperatingState.CLOSED:
self.sys_log.info(f"Running Application {self.name}")
self.operating_state = ApplicationOperatingState.RUNNING
def _application_loop(self):
"""The main application loop."""
pass
def close(self) -> None:
"""Close the Application."""
if self.operating_state == ApplicationOperatingState.RUNNING:
@@ -78,15 +113,6 @@ class Application(IOSoftware):
self.sys_log.info(f"Installing Application {self.name}")
self.operating_state = ApplicationOperatingState.INSTALLING
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.
@@ -97,4 +123,4 @@ class Application(IOSoftware):
:param payload: The payload to receive.
:return: True if successful, False otherwise.
"""
pass
return super().receive(payload=payload, session_id=session_id, **kwargs)

View File

@@ -31,6 +31,20 @@ class DatabaseClient(Application):
kwargs["port"] = Port.POSTGRES_SERVER
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
self.set_original_state()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
self._query_success_tracker.clear()
def describe_state(self) -> Dict:
"""
@@ -54,9 +68,13 @@ class DatabaseClient(Application):
def connect(self) -> bool:
"""Connect to a Database Service."""
if not self.connected and self.operating_state.RUNNING:
if not self._can_perform_action():
return False
if not self.connected:
return self._connect(self.server_ip_address, self.server_password)
return False
# already connected
return True
def _connect(
self, server_ip_address: IPv4Address, password: Optional[str] = None, is_reattempt: bool = False
@@ -75,11 +93,11 @@ class DatabaseClient(Application):
"""
if is_reattempt:
if self.connected:
self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} authorised")
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} authorised")
self.server_ip_address = server_ip_address
return self.connected
else:
self.sys_log.info(f"{self.name}: DatabaseClient connected to {server_ip_address} declined")
self.sys_log.info(f"{self.name}: DatabaseClient connection to {server_ip_address} declined")
return False
payload = {"type": "connect_request", "password": password}
software_manager: SoftwareManager = self.software_manager
@@ -90,7 +108,7 @@ class DatabaseClient(Application):
def disconnect(self):
"""Disconnect from the Database Service."""
if self.connected and self.operating_state.RUNNING:
if self.connected and self.operating_state is ApplicationOperatingState.RUNNING:
software_manager: SoftwareManager = self.software_manager
software_manager.send_payload_to_session_manager(
payload={"type": "disconnect"}, dest_ip_address=self.server_ip_address, dest_port=self.port
@@ -132,22 +150,34 @@ class DatabaseClient(Application):
def run(self) -> None:
"""Run the DatabaseClient."""
super().run()
self.operating_state = ApplicationOperatingState.RUNNING
self.connect()
if self.operating_state == ApplicationOperatingState.RUNNING:
self.connect()
def query(self, sql: str) -> bool:
def query(self, sql: str, is_reattempt: bool = False) -> bool:
"""
Send a query to the Database Service.
:param sql: The SQL query.
:param: sql: The SQL query.
:param: is_reattempt: If true, the action has been reattempted.
:return: True if the query was successful, otherwise False.
"""
if self.connected and self.operating_state.RUNNING:
if not self._can_perform_action():
return False
if self.connected:
query_id = str(uuid4())
# Initialise the tracker of this ID to False
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
else:
if is_reattempt:
return False
if not self.connect():
return False
self.query(sql=sql, is_reattempt=True)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
@@ -157,6 +187,9 @@ class DatabaseClient(Application):
:param session_id: The session id the payload relates to.
:return: True.
"""
if not self._can_perform_action():
return False
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_response":
self.connected = payload["response"] == True
@@ -166,4 +199,6 @@ class DatabaseClient(Application):
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
_LOGGER.debug(f"Received payload {payload}")
else:
self.connected = False
return True

View File

@@ -2,12 +2,21 @@ from ipaddress import IPv4Address
from typing import Dict, Optional
from urllib.parse import urlparse
from primaite.simulator.network.protocols.http import HttpRequestMethod, HttpRequestPacket, HttpResponsePacket
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
HttpRequestPacket,
HttpResponsePacket,
HttpStatusCode,
)
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.services.dns.dns_client import DNSClient
_LOGGER = getLogger(__name__)
class WebBrowser(Application):
"""
@@ -16,6 +25,8 @@ class WebBrowser(Application):
The application requests and loads web pages using its domain name and requesting IP addresses using DNS.
"""
target_url: Optional[str] = None
domain_name_ip_address: Optional[IPv4Address] = None
"The IP address of the domain name for the webpage."
@@ -30,8 +41,29 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.set_original_state()
self.run()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa
)
return rm
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of the WebBrowser.
@@ -42,16 +74,9 @@ class WebBrowser(Application):
state["last_response_status_code"] = self.latest_response.status_code if self.latest_response else None
def reset_component_for_episode(self, episode: int):
"""
Resets the Application component for a new episode.
"""Reset the original state of the SimComponent."""
This method ensures the Application is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
self.domain_name_ip_address = None
self.latest_response = None
def get_webpage(self, url: str) -> bool:
def get_webpage(self, url: Optional[str] = None) -> bool:
"""
Retrieve the webpage.
@@ -60,8 +85,12 @@ class WebBrowser(Application):
:param: url: The address of the web page the browser requests
:type: url: str
"""
url = url or self.target_url
if not self._can_perform_action():
return False
# reset latest response
self.latest_response = None
self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND)
try:
parsed_url = urlparse(url)
@@ -70,8 +99,7 @@ class WebBrowser(Application):
return False
# get the IP address of the domain name via DNS
dns_client: DNSClient = self.software_manager.software["DNSClient"]
dns_client: DNSClient = self.software_manager.software.get("DNSClient")
domain_exists = dns_client.check_domain_exists(target_domain=parsed_url.hostname)
# if domain does not exist, the request fails
@@ -91,11 +119,19 @@ class WebBrowser(Application):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url=url)
# send request
return self.send(
if self.send(
payload=payload,
dest_ip_address=self.domain_name_ip_address,
dest_port=parsed_url.port if parsed_url.port else Port.HTTP,
)
):
self.sys_log.info(
f"{self.name}: Received HTTP {payload.request_method.name} "
f"Response {payload.request_url} - {self.latest_response.status_code.value}"
)
return self.latest_response.status_code is HttpStatusCode.OK
else:
self.sys_log.error(f"Error sending Http Packet {str(payload)}")
return False
def send(
self,

View File

@@ -34,9 +34,12 @@ class PacketCapture:
"The IP address associated with the PCAP logs."
self.switch_port_number = switch_port_number
"The SwitchPort number."
self._setup_logger()
def _setup_logger(self):
self.current_episode: int = 1
self.setup_logger()
def setup_logger(self):
"""Set up the logger configuration."""
log_path = self._get_log_path()
@@ -75,7 +78,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = SIM_OUTPUT.path / self.hostname
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -93,6 +93,11 @@ class SessionManager:
"""
pass
def clear(self):
"""Clears the sessions."""
self.sessions_by_key.clear()
self.sessions_by_uuid.clear()
@staticmethod
def _get_session_key(
frame: Frame, inbound_frame: bool = True

View File

@@ -31,9 +31,10 @@ class SysLog:
:param hostname: The hostname associated with the system logs being recorded.
"""
self.hostname = hostname
self._setup_logger()
self.current_episode: int = 1
self.setup_logger()
def _setup_logger(self):
def setup_logger(self):
"""
Configures the logger for this SysLog instance.
@@ -80,7 +81,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = SIM_OUTPUT.path / self.hostname
root = SIM_OUTPUT.path / f"episode_{self.current_episode}" / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -24,6 +24,12 @@ class Process(Software):
operating_state: ProcessOperatingState
"The current operating state of the Process."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""

View File

@@ -2,6 +2,7 @@ from datetime import datetime
from ipaddress import IPv4Address
from typing import Any, Dict, List, Literal, Optional, Union
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
@@ -10,6 +11,8 @@ from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.service import Service, ServiceOperatingState
from primaite.simulator.system.software import SoftwareHealthState
_LOGGER = getLogger(__name__)
class DatabaseService(Service):
"""
@@ -38,6 +41,25 @@ class DatabaseService(Service):
self._db_file: File
self._create_db_file()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"password",
"connections",
"backup_server",
"latest_backup_directory",
"latest_backup_file_name",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}")
self.connections.clear()
super().reset_component_for_episode(episode)
def configure_backup(self, backup_server: IPv4Address):
"""
Set up the database backup.
@@ -48,13 +70,17 @@ class DatabaseService(Service):
def backup_database(self) -> bool:
"""Create a backup of the database to the configured backup server."""
# check if this action can be performed
if not self._can_perform_action():
return False
# check if the backup server was configured
if self.backup_server is None:
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
# send backup copy of database file to FTP server
response = ftp_client_service.send_file(
@@ -73,8 +99,12 @@ class DatabaseService(Service):
def restore_backup(self) -> bool:
"""Restore a backup from backup server."""
# check if this action can be performed
if not self._can_perform_action():
return False
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
ftp_client_service: FTPClient = software_manager.software.get("FTPClient")
# retrieve backup file from backup server
response = ftp_client_service.request_file(
@@ -173,6 +203,9 @@ class DatabaseService(Service):
:param session_id: The session identifier.
:return: True if the Status Code is 200, otherwise False.
"""
if not super().receive(payload=payload, session_id=session_id, **kwargs):
return False
result = {"status_code": 500, "data": []}
if isinstance(payload, dict) and payload.get("type"):
if payload["type"] == "connect_request":

View File

@@ -29,6 +29,18 @@ class DNSClient(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_server"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_cache.clear()
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -42,23 +54,18 @@ class DNSClient(Service):
state = super().describe_state()
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address):
def add_domain_to_cache(self, domain_name: str, ip_address: IPv4Address) -> bool:
"""
Adds a domain name to the DNS Client cache.
:param: domain_name: The domain name to save to cache
:param: ip_address: The IP Address to attach the domain name to
"""
if not self._can_perform_action():
return False
self.dns_cache[domain_name] = ip_address
return True
def check_domain_exists(
self,
@@ -72,6 +79,9 @@ class DNSClient(Service):
:param: session_id: The Session ID the payload is to originate from. Optional.
:param: is_reattempt: Checks if the request has been reattempted. Default is False.
"""
if not self._can_perform_action():
return False
# check if DNS server is configured
if self.dns_server is None:
self.sys_log.error(f"{self.name}: DNS Server is not configured")

View File

@@ -28,6 +28,20 @@ class DNSServer(Service):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"dns_table"}
self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"]
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
self.dns_table.clear()
for key, value in self._original_state["dns_table_orig"].items():
self.dns_table[key] = value
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Describes the current state of the software.
@@ -48,6 +62,9 @@ class DNSServer(Service):
:param target_domain: The single domain name requested by a DNS client.
:return ip_address: The IP address of that domain name or None.
"""
if not self._can_perform_action():
return
return self.dns_table.get(target_domain)
def dns_register(self, domain_name: str, domain_ip_address: IPv4Address):
@@ -60,17 +77,11 @@ class DNSServer(Service):
:param: domain_ip_address: The IP address that the domain should route to
:type: domain_ip_address: IPv4Address
"""
if not self._can_perform_action():
return
self.dns_table[domain_name] = domain_ip_address
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def receive(
self,
payload: Any,
@@ -88,10 +99,14 @@ class DNSServer(Service):
:return: True if DNS request returns a valid IP, otherwise, False
"""
if not super().receive(payload=payload, session_id=session_id, **kwargs):
return False
# The payload should be a DNS packet
if not isinstance(payload, DNSPacket):
_LOGGER.debug(f"{payload} is not a DNSPacket")
return False
# cast payload into a DNS packet
payload: DNSPacket = payload
if payload.dns_request is not None:

View File

@@ -1,13 +1,15 @@
from ipaddress import IPv4Address
from typing import Optional
from primaite import getLogger
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.software_manager import SoftwareManager
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
_LOGGER = getLogger(__name__)
class FTPClient(FTPServiceABC):
@@ -28,6 +30,18 @@ class FTPClient(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"connected"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
@@ -38,8 +52,7 @@ class FTPClient(FTPServiceABC):
:type: session_id: Optional[str]
"""
# if client service is down, return error
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.error("FTP Client is not running")
if not self._can_perform_action():
payload.status_code = FTPStatusCode.ERROR
return payload
@@ -66,16 +79,12 @@ class FTPClient(FTPServiceABC):
:type: is_reattempt: Optional[bool]
"""
# make sure the service is running before attempting
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.error(f"FTPClient not running for {self.sys_log.hostname}")
if not self._can_perform_action():
return False
# normally FTP will choose a random port for the transfer, but using the FTP command port will do for now
# create FTP packet
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
)
payload: FTPPacket = FTPPacket(ftp_command=FTPCommand.PORT, ftp_command_args=Port.FTP)
if self.send(payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id):
if payload.status_code == FTPStatusCode.OK:
@@ -270,8 +279,14 @@ class FTPClient(FTPServiceABC):
This helps prevent an FTP request loop - FTP client and servers can exist on
the same node.
"""
if payload.status_code is None:
if not self._can_perform_action():
return False
if payload.status_code is None:
self.sys_log.error(f"FTP Server could not be found - Error Code: {FTPStatusCode.NOT_FOUND.value}")
return False
self.sys_log.info(f"{self.name}: Received FTP Response {payload.ftp_command.name} {payload.status_code.value}")
self._process_ftp_command(payload=payload, session_id=session_id)
return True

View File

@@ -1,11 +1,13 @@
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_service import FTPServiceABC
from primaite.simulator.system.services.service import ServiceOperatingState
_LOGGER = getLogger(__name__)
class FTPServer(FTPServiceABC):
@@ -29,6 +31,19 @@ class FTPServer(FTPServiceABC):
super().__init__(**kwargs)
self.start()
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"server_password"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}")
self.connections.clear()
super().reset_component_for_episode(episode)
def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket:
"""
Process the command in the FTP Packet.
@@ -42,8 +57,7 @@ class FTPServer(FTPServiceABC):
payload.status_code = FTPStatusCode.ERROR
# if server service is down, return error
if self.operating_state != ServiceOperatingState.RUNNING:
self.sys_log.error("FTP Server not running")
if not self._can_perform_action():
return payload
self.sys_log.info(f"{self.name}: Received FTP {payload.ftp_command.name} {payload.ftp_command_args}")
@@ -79,6 +93,9 @@ class FTPServer(FTPServiceABC):
self.sys_log.error(f"{payload} is not an FTP packet")
return False
if not super().receive(payload=payload, session_id=session_id, **kwargs):
return False
"""
Ignore ftp payload if status code is defined.

View File

@@ -1,27 +1,90 @@
from enum import IntEnum
from ipaddress import IPv4Address
from typing import Optional
from primaite import getLogger
from primaite.game.science import simulate_trial
from primaite.simulator.core import RequestManager, RequestType
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
_LOGGER = getLogger(__name__)
class DataManipulationAttackStage(IntEnum):
"""
Enumeration representing different stages of a data manipulation attack.
This enumeration defines the various stages a data manipulation attack can be in during its lifecycle in the
simulation. Each stage represents a specific phase in the attack process.
"""
NOT_STARTED = 0
"Indicates that the attack has not started yet."
LOGON = 1
"The stage where logon procedures are simulated."
PORT_SCAN = 2
"Represents the stage of performing a horizontal port scan on the target."
ATTACKING = 3
"Stage of actively attacking the target."
SUCCEEDED = 4
"Indicates the attack has been successfully completed."
FAILED = 5
"Signifies that the attack has failed."
class DataManipulationBot(DatabaseClient):
"""
Red Agent Data Integration Service.
The Service represents a bot that causes files/folders in the File System to
become corrupted.
"""
"""A bot that simulates a script which performs a SQL injection attack."""
server_ip_address: Optional[IPv4Address] = None
payload: Optional[str] = None
server_password: Optional[str] = None
port_scan_p_of_success: float = 0.1
data_manipulation_p_of_success: float = 0.1
attack_stage: DataManipulationAttackStage = DataManipulationAttackStage.NOT_STARTED
repeat: bool = False
"Whether to repeat attacking once finished."
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.name = "DataManipulationBot"
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {
"server_ip_address",
"payload",
"server_password",
"port_scan_p_of_success",
"data_manipulation_p_of_success",
"attack_stage",
"repeat",
}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run()))
return rm
def configure(
self, server_ip_address: IPv4Address, server_password: Optional[str] = None, payload: Optional[str] = None
self,
server_ip_address: IPv4Address,
server_password: Optional[str] = None,
payload: Optional[str] = None,
port_scan_p_of_success: float = 0.1,
data_manipulation_p_of_success: float = 0.1,
repeat: bool = False,
):
"""
Configure the DataManipulatorBot to communicate with a DatabaseService.
@@ -29,23 +92,111 @@ class DataManipulationBot(DatabaseClient):
:param server_ip_address: The IP address of the Node the DatabaseService is on.
:param server_password: The password on the DatabaseService.
:param payload: The data manipulation query payload.
:param port_scan_p_of_success: The probability of success for the port scan stage.
:param data_manipulation_p_of_success: The probability of success for the data manipulation stage.
:param repeat: Whether to repeat attacking once finished.
"""
self.server_ip_address = server_ip_address
self.payload = payload
self.server_password = server_password
self.port_scan_p_of_success = port_scan_p_of_success
self.data_manipulation_p_of_success = data_manipulation_p_of_success
self.repeat = repeat
self.sys_log.info(
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}."
f"{self.name}: Configured the {self.name} with {server_ip_address=}, {payload=}, {server_password=}, "
f"{repeat=}."
)
def _logon(self):
"""
Simulate the logon process as the initial stage of the attack.
Advances the attack stage to `LOGON` if successful.
"""
if self.attack_stage == DataManipulationAttackStage.NOT_STARTED:
# Bypass this stage as we're not dealing with logon for now
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DataManipulationAttackStage.LOGON
def _perform_port_scan(self, p_of_success: Optional[float] = 0.1):
"""
Perform a simulated port scan to check for open SQL ports.
Advances the attack stage to `PORT_SCAN` if successful.
:param p_of_success: Probability of successful port scan, by default 0.1.
"""
if self.attack_stage == DataManipulationAttackStage.LOGON:
# perform a port scan to identify that the SQL port is open on the server
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing port scan")
# perform the port scan
port_is_open = True # Temporary; later we can implement NMAP port scan.
if port_is_open:
self.sys_log.info(f"{self.name}: ")
self.attack_stage = DataManipulationAttackStage.PORT_SCAN
def _perform_data_manipulation(self, p_of_success: Optional[float] = 0.1):
"""
Execute the data manipulation attack on the target.
Advances the attack stage to `COMPLETE` if successful, or 'FAILED' if unsuccessful.
:param p_of_success: Probability of successfully performing data manipulation, by default 0.1.
"""
if self.attack_stage == DataManipulationAttackStage.PORT_SCAN:
# perform the actual data manipulation attack
if simulate_trial(p_of_success):
self.sys_log.info(f"{self.name}: Performing data manipulation")
# perform the attack
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
attack_successful = True
if attack_successful:
self.sys_log.info(f"{self.name}: Data manipulation successful")
self.attack_stage = DataManipulationAttackStage.SUCCEEDED
else:
self.sys_log.info(f"{self.name}: Data manipulation failed")
self.attack_stage = DataManipulationAttackStage.FAILED
def run(self):
"""Run the DataManipulationBot."""
if self.server_ip_address and self.payload:
self.sys_log.info(f"{self.name}: Attempting to start the {self.name}")
super().run()
if not self.connected:
self.connect()
if self.connected:
self.query(self.payload)
self.sys_log.info(f"{self.name} payload delivered: {self.payload}")
"""
Run the Data Manipulation Bot.
Calls the parent classes execute method before starting the application loop.
"""
super().run()
self._application_loop()
def _application_loop(self):
"""
The main application loop of the bot, handling the attack process.
This is the core loop where the bot sequentially goes through the stages of the attack.
"""
if self.operating_state != ApplicationOperatingState.RUNNING:
return
if self.server_ip_address and self.payload and self.operating_state:
self.sys_log.info(f"{self.name}: Running")
self._logon()
self._perform_port_scan(p_of_success=self.port_scan_p_of_success)
self._perform_data_manipulation(p_of_success=self.data_manipulation_p_of_success)
if self.repeat and self.attack_stage in (
DataManipulationAttackStage.SUCCEEDED,
DataManipulationAttackStage.FAILED,
):
self.attack_stage = DataManipulationAttackStage.NOT_STARTED
else:
self.sys_log.error(f"Failed to start the {self.name} as it requires both a target_ip_address and payload.")
self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.")
def apply_timestep(self, timestep: int) -> None:
"""
Apply a timestep to the bot, triggering the application loop.
:param timestep: The timestep value to update the bot's state.
"""
self._application_loop()

View File

@@ -1,5 +1,5 @@
from enum import Enum
from typing import Dict, Optional
from typing import Any, Dict, Optional
from primaite import getLogger
from primaite.simulator.core import RequestManager, RequestType
@@ -40,12 +40,52 @@ class Service(IOSoftware):
restart_countdown: Optional[int] = None
"If currently restarting, how many timesteps remain until the restart is finished."
def _can_perform_action(self) -> bool:
"""
Checks if the service can perform actions.
This is done by checking if the service is operating properly or the node it is installed
in is operational.
Returns true if the software can perform actions.
"""
if not super()._can_perform_action():
return False
if self.operating_state is not self.operating_state.RUNNING:
# service is not running
_LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}")
return False
return True
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receives a payload from the SessionManager.
The specifics of how the payload is processed and whether a response payload
is generated should be implemented in subclasses.
:param payload: The payload to receive.
:param session_id: The identifier of the session that the payload is associated with.
:param kwargs: Additional keyword arguments specific to the implementation.
:return: True if the payload was successfully received and processed, False otherwise.
"""
return super().receive(payload=payload, session_id=session_id, **kwargs)
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.health_state_visible = SoftwareHealthState.UNUSED
self.health_state_actual = SoftwareHealthState.UNUSED
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"operating_state", "restart_duration", "restart_countdown"}
self._original_state.update(self.model_dump(include=vals_to_include))
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request("scan", RequestType(func=lambda request, context: self.scan()))
@@ -69,19 +109,10 @@ class Service(IOSoftware):
"""
state = super().describe_state()
state["operating_state"] = self.operating_state.value
state["health_state_actual"] = self.health_state_actual
state["health_state_visible"] = self.health_state_visible
state["health_state_actual"] = self.health_state_actual.value
state["health_state_visible"] = self.health_state_visible.value
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the Service component for a new episode.
This method ensures the Service is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues.
"""
pass
def stop(self) -> None:
"""Stop the service."""
if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]:
@@ -91,6 +122,10 @@ class Service(IOSoftware):
def start(self, **kwargs) -> None:
"""Start the service."""
# cant start the service if the node it is on is off
if not super()._can_perform_action():
return
if self.operating_state == ServiceOperatingState.STOPPED:
self.sys_log.info(f"Starting service {self.name}")
self.operating_state = ServiceOperatingState.RUNNING

View File

@@ -2,6 +2,7 @@ from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from urllib.parse import urlparse
from primaite import getLogger
from primaite.simulator.network.protocols.http import (
HttpRequestMethod,
HttpRequestPacket,
@@ -13,12 +14,26 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.service import Service
_LOGGER = getLogger(__name__)
class WebServer(Service):
"""Class used to represent a Web Server Service in simulation."""
last_response_status_code: Optional[HttpStatusCode] = None
def set_original_state(self):
"""Sets the original state."""
_LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}")
super().set_original_state()
vals_to_include = {"last_response_status_code"}
self._original_state.update(self.model_dump(include=vals_to_include))
def reset_component_for_episode(self, episode: int):
"""Reset the original state of the SimComponent."""
_LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}")
super().reset_component_for_episode(episode)
def describe_state(self) -> Dict:
"""
Produce a dictionary describing the current state of this object.
@@ -30,7 +45,7 @@ class WebServer(Service):
"""
state = super().describe_state()
state["last_response_status_code"] = (
self.last_response_status_code.value if self.last_response_status_code else None
self.last_response_status_code.value if isinstance(self.last_response_status_code, HttpStatusCode) else None
)
return state
@@ -104,7 +119,7 @@ class WebServer(Service):
if path.startswith("users"):
# get data from DatabaseServer
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
db_client: DatabaseClient = self.software_manager.software.get("DatabaseClient")
# get all users
if db_client.query("SELECT"):
# query succeeded
@@ -155,6 +170,9 @@ class WebServer(Service):
:param: payload: The payload to send.
:param: session_id: The id of the session. Optional.
"""
if not super().receive(payload=payload, session_id=session_id, **kwargs):
return False
# check if the payload is an HTTPPacket
if not isinstance(payload, HttpRequestPacket):
self.sys_log.error("Payload is not an HTTPPacket")

View File

@@ -3,8 +3,9 @@ from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from primaite.simulator.core import RequestManager, RequestType, SimComponent
from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent
from primaite.simulator.file_system.file_system import FileSystem, Folder
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.core.session_manager import Session
from primaite.simulator.system.core.sys_log import SysLog
@@ -89,6 +90,19 @@ class Software(SimComponent):
folder: Optional[Folder] = None
"The folder on the file system the Software uses."
def set_original_state(self):
"""Sets the original state."""
vals_to_include = {
"name",
"health_state_actual",
"health_state_visible",
"criticality",
"patching_count",
"scanning_count",
"revealed_to_red",
}
self._original_state = self.model_dump(include=vals_to_include)
def _init_request_manager(self) -> RequestManager:
rm = super()._init_request_manager()
rm.add_request(
@@ -131,16 +145,6 @@ class Software(SimComponent):
)
return state
def reset_component_for_episode(self, episode: int):
"""
Resets the software component for a new episode.
This method should ensure the software is ready for a new episode, including resetting any
stateful properties or statistics, and clearing any message queues. The specifics of what constitutes a
"reset" should be implemented in subclasses.
"""
pass
def set_health_state(self, health_state: SoftwareHealthState) -> None:
"""
Assign a new health state to this software.
@@ -203,6 +207,12 @@ class IOSoftware(Software):
port: Port
"The port to which the software is connected."
def set_original_state(self):
"""Sets the original state."""
super().set_original_state()
vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"}
self._original_state.update(self.model_dump(include=vals_to_include))
@abstractmethod
def describe_state(self) -> Dict:
"""
@@ -225,6 +235,21 @@ class IOSoftware(Software):
)
return state
@abstractmethod
def _can_perform_action(self) -> bool:
"""
Checks if the software can perform actions.
This is done by checking if the software is operating properly or the node it is installed
in is operational.
Returns true if the software can perform actions.
"""
if self.software_manager and self.software_manager.node.operating_state is NodeOperatingState.OFF:
_LOGGER.debug(f"{self.name} Error: {self.software_manager.node.hostname} is not online.")
return False
return True
def send(
self,
payload: Any,
@@ -243,6 +268,9 @@ class IOSoftware(Software):
:return: True if successful, False otherwise.
"""
if not self._can_perform_action():
return False
return self.software_manager.send_payload_to_session_manager(
payload=payload, dest_ip_address=dest_ip_address, dest_port=dest_port, session_id=session_id
)
@@ -261,4 +289,5 @@ class IOSoftware(Software):
:param kwargs: Additional keyword arguments specific to the implementation.
:return: True if the payload was successfully received and processed, False otherwise.
"""
pass
# return false if not allowed to perform actions
return self._can_perform_action()

View File

@@ -27,14 +27,6 @@ agents:
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
@@ -48,10 +40,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -60,38 +53,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -101,9 +76,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -652,9 +628,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -52,10 +52,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -64,50 +65,32 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -656,9 +639,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -58,10 +58,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -70,38 +71,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -111,9 +94,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender1
team: BLUE
@@ -1093,9 +1077,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -56,10 +56,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -68,38 +69,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -109,9 +92,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -660,9 +644,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -52,10 +52,11 @@ agents:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: client_1_data_manipulation_red_bot
team: RED
@@ -64,38 +65,20 @@ agents:
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
nodes: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
applications:
- application_ref: data_manipulation_bot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
@@ -105,9 +88,10 @@ agents:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
@@ -656,9 +640,15 @@ simulation:
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
applications:
- ref: data_manipulation_bot
type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- ref: client_1_dns_client
type: DNSClient

View File

@@ -1,6 +1,6 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Union
from typing import Any, Dict, Tuple, Union
import pytest
import yaml
@@ -12,7 +12,13 @@ from primaite.session.session import PrimaiteSession
# from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.hardware.nodes.switch import Switch
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.sys_log import SysLog
@@ -28,12 +34,18 @@ from primaite import PRIMAITE_PATHS
# PrimAITE v3 stuff
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.base import Link, Node
class TestService(Service):
"""Test Service class"""
def __init__(self, **kwargs):
kwargs["name"] = "TestService"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
pass
@@ -41,6 +53,12 @@ class TestService(Service):
class TestApplication(Application):
"""Test Application class"""
def __init__(self, **kwargs):
kwargs["name"] = "TestApplication"
kwargs["port"] = Port.HTTP
kwargs["protocol"] = IPProtocol.TCP
super().__init__(**kwargs)
def describe_state(self) -> Dict:
pass
@@ -57,6 +75,11 @@ def service(file_system) -> TestService:
)
@pytest.fixture(scope="function")
def service_class():
return TestService
@pytest.fixture(scope="function")
def application(file_system) -> TestApplication:
return TestApplication(
@@ -64,6 +87,11 @@ def application(file_system) -> TestApplication:
)
@pytest.fixture(scope="function")
def application_class():
return TestApplication
@pytest.fixture(scope="function")
def file_system() -> FileSystem:
return Node(hostname="fs_node").file_system
@@ -99,3 +127,110 @@ def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
config_path = request.param[0]
return TempPrimaiteSession.from_config(config_path=config_path)
@pytest.fixture(scope="function")
def client_server() -> Tuple[Computer, Server]:
# Create Computer
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.0.1",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
# Create Server
server = Server(
hostname="server", ip_address="192.168.0.2", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
)
# Connect Computer and Server
computer_nic = computer.nics[next(iter(computer.nics))]
server_nic = server.nics[next(iter(server.nics))]
link = Link(endpoint_a=computer_nic, endpoint_b=server_nic)
# Should be linked
assert link.is_up
return computer, server
@pytest.fixture(scope="function")
def example_network() -> Network:
"""
Create the network used for testing.
Should only contain the nodes and links.
This would act as the base network and services and applications are installed in the relevant test file,
-------------- --------------
| client_1 |----- ----| server_1 |
-------------- | -------------- -------------- -------------- | --------------
------| switch_1 |------| router_1 |------| switch_2 |------
-------------- | -------------- -------------- -------------- | --------------
| client_2 |---- ----| server_2 |
-------------- --------------
"""
network = Network()
# Router 1
router_1 = Router(hostname="router_1", num_ports=5, operating_state=NodeOperatingState.ON)
router_1.configure_port(port=1, ip_address="192.168.1.1", subnet_mask="255.255.255.0")
router_1.configure_port(port=2, ip_address="192.168.10.1", subnet_mask="255.255.255.0")
# Switch 1
switch_1 = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
network.connect(endpoint_a=router_1.ethernet_ports[1], endpoint_b=switch_1.switch_ports[8])
router_1.enable_port(1)
# Switch 2
switch_2 = Switch(hostname="switch_2", num_ports=8, operating_state=NodeOperatingState.ON)
network.connect(endpoint_a=router_1.ethernet_ports[2], endpoint_b=switch_2.switch_ports[8])
router_1.enable_port(2)
# Client 1
client_1 = Computer(
hostname="client_1",
ip_address="192.168.10.21",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
# Client 2
client_2 = Computer(
hostname="client_2",
ip_address="192.168.10.22",
subnet_mask="255.255.255.0",
default_gateway="192.168.10.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=client_2.ethernet_port[1], endpoint_a=switch_2.switch_ports[2])
# Domain Controller
server_1 = Server(
hostname="server_1",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=server_1.ethernet_port[1], endpoint_a=switch_1.switch_ports[1])
# Database Server
server_2 = Server(
hostname="server_2",
ip_address="192.168.1.14",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
network.connect(endpoint_b=server_2.ethernet_port[1], endpoint_a=switch_1.switch_ports[2])
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22)
router_1.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23)
return network

View File

@@ -2,6 +2,7 @@
import tempfile
from pathlib import Path
import pytest
import yaml
from stable_baselines3 import PPO
@@ -10,6 +11,7 @@ from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
# @pytest.mark.skip(reason="no way of currently testing this")
def test_sb3_compatibility():
"""Test that the Gymnasium environment can be used with an SB3 agent."""
with open(example_config_path(), "r") as f:

View File

@@ -11,6 +11,7 @@ MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
# @pytest.mark.skip(reason="no way of currently testing this")
class TestPrimaiteSession:
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_creating_session(self, temp_primaite_session):
@@ -76,6 +77,10 @@ class TestPrimaiteSession:
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.skip(
reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, "
"reset doesn't implement software restore."
)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:

View File

@@ -8,13 +8,13 @@ from primaite.simulator.system.services.red_services.data_manipulation_bot impor
def test_data_manipulation(uc2_network):
"""Tests the UC2 data manipulation scenario end-to-end. Is a work in progress."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot")
database_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = database_server.software_manager.software["DatabaseService"]
db_service: DatabaseService = database_server.software_manager.software.get("DatabaseService")
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_service.backup_database()

View File

@@ -16,3 +16,9 @@ def test_link_up():
assert nic_a.enabled
assert nic_b.enabled
assert link.is_up
def test_ping_between_computer_and_server(client_server):
computer, server = client_server
assert computer.ping(target_ip_address=server.nics[next(iter(server.nics))].ip_address)

View File

@@ -2,6 +2,28 @@ import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.networks import client_server_routed
def test_network(example_network):
network: Network = example_network
client_1: Computer = network.get_node_by_hostname("client_1")
client_2: Computer = network.get_node_by_hostname("client_2")
server_1: Server = network.get_node_by_hostname("server_1")
server_2: Server = network.get_node_by_hostname("server_2")
assert client_1.ping(client_2.ethernet_port[1].ip_address)
assert client_2.ping(client_1.ethernet_port[1].ip_address)
assert server_1.ping(server_2.ethernet_port[1].ip_address)
assert server_2.ping(server_1.ethernet_port[1].ip_address)
assert client_1.ping(server_1.ethernet_port[1].ip_address)
assert client_2.ping(server_1.ethernet_port[1].ip_address)
assert client_1.ping(server_2.ethernet_port[1].ip_address)
assert client_2.ping(server_2.ethernet_port[1].ip_address)
def test_adding_removing_nodes():

View File

@@ -0,0 +1,121 @@
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
@pytest.fixture(scope="function")
def populated_node(application_class) -> Tuple[Application, Computer]:
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
computer.software_manager.install(application_class)
app = computer.software_manager.software.get("TestApplication")
app.run()
return app, computer
def test_service_on_offline_node(application_class):
"""Test to check that the service cannot be interacted with when node it is on is off."""
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
computer.software_manager.install(application_class)
app: Application = computer.software_manager.software.get("TestApplication")
computer.power_off()
for i in range(computer.shut_down_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.OFF
assert app.operating_state is ApplicationOperatingState.CLOSED
app.run()
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_server_turns_off_service(populated_node):
"""Check that the service is turned off when the server is turned off"""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON
assert app.operating_state is ApplicationOperatingState.RUNNING
computer.power_off()
for i in range(computer.shut_down_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.OFF
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_service_cannot_be_turned_on_when_server_is_off(populated_node):
"""Check that the service cannot be started when the server is off."""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON
assert app.operating_state is ApplicationOperatingState.RUNNING
computer.power_off()
for i in range(computer.shut_down_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.OFF
assert app.operating_state is ApplicationOperatingState.CLOSED
app.run()
assert computer.operating_state is NodeOperatingState.OFF
assert app.operating_state is ApplicationOperatingState.CLOSED
def test_server_turns_on_service(populated_node):
"""Check that turning on the server turns on service."""
app, computer = populated_node
assert computer.operating_state is NodeOperatingState.ON
assert app.operating_state is ApplicationOperatingState.RUNNING
computer.power_off()
for i in range(computer.shut_down_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.OFF
assert app.operating_state is ApplicationOperatingState.CLOSED
computer.power_on()
for i in range(computer.start_up_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.ON
assert app.operating_state is ApplicationOperatingState.RUNNING
computer.start_up_duration = 0
computer.shut_down_duration = 0
computer.power_off()
assert computer.operating_state is NodeOperatingState.OFF
assert app.operating_state is ApplicationOperatingState.CLOSED
computer.power_on()
assert computer.operating_state is NodeOperatingState.ON
assert app.operating_state is ApplicationOperatingState.RUNNING

View File

@@ -1,17 +1,19 @@
from ipaddress import IPv4Address
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
def test_database_client_server_connection(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
assert len(db_service.connections) == 1
@@ -21,10 +23,10 @@ def test_database_client_server_connection(uc2_network):
def test_database_client_server_correct_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client.disconnect()
@@ -38,10 +40,10 @@ def test_database_client_server_correct_password(uc2_network):
def test_database_client_server_incorrect_password(uc2_network):
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_client.disconnect()
db_client.configure(server_ip_address=IPv4Address("192.168.1.14"), server_password="54321")
@@ -54,8 +56,9 @@ def test_database_client_server_incorrect_password(uc2_network):
def test_database_client_query(uc2_network):
"""Tests DB query across the network returns HTTP status 200 and date."""
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected
assert db_client.query("SELECT")
@@ -63,13 +66,13 @@ def test_database_client_query(uc2_network):
def test_create_database_backup(uc2_network):
"""Run the backup_database method and check if the FTP server has the relevant file."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
# back up should be created
assert db_service.backup_database() is True
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
ftp_server: FTPServer = backup_server.software_manager.software.get("FTPServer")
# backup file should exist in the backup server
assert ftp_server.file_system.get_file(folder_name=db_service.uuid, file_name="database.db") is not None
@@ -78,7 +81,7 @@ def test_create_database_backup(uc2_network):
def test_restore_backup(uc2_network):
"""Run the restore_backup method and check if the backup is properly restored."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software["DatabaseService"]
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
# create a back up
assert db_service.backup_database() is True
@@ -92,3 +95,28 @@ def test_restore_backup(uc2_network):
assert db_service.restore_backup() is True
assert db_service.file_system.get_file(folder_name="database", file_name="database.db") is not None
def test_database_client_cannot_query_offline_database_server(uc2_network):
"""Tests DB query across the network returns HTTP status 404 when db server is offline."""
db_server: Server = uc2_network.get_node_by_hostname("database_server")
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
assert db_server.operating_state is NodeOperatingState.ON
assert db_service.operating_state is ServiceOperatingState.RUNNING
web_server: Server = uc2_network.get_node_by_hostname("web_server")
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
assert db_client.connected
assert db_client.query("SELECT") is True
db_server.power_off()
for i in range(db_server.shut_down_duration + 1):
uc2_network.apply_timestep(timestep=i)
assert db_server.operating_state is NodeOperatingState.OFF
assert db_service.operating_state is ServiceOperatingState.STOPPED
assert db_client.query("SELECT") is False

View File

@@ -1,3 +1,9 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.services.dns.dns_client import DNSClient
@@ -5,12 +11,31 @@ from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.service import ServiceOperatingState
def test_dns_client_server(uc2_network):
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
domain_controller: Server = uc2_network.get_node_by_hostname("domain_controller")
@pytest.fixture(scope="function")
def dns_client_and_dns_server(client_server) -> Tuple[DNSClient, Computer, DNSServer, Server]:
computer, server = client_server
dns_client: DNSClient = client_1.software_manager.software["DNSClient"]
dns_server: DNSServer = domain_controller.software_manager.software["DNSServer"]
# Install DNS Client on computer
computer.software_manager.install(DNSClient)
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
dns_client.start()
# set server as DNS Server
dns_client.dns_server = IPv4Address(server.nics.get(next(iter(server.nics))).ip_address)
# Install DNS Server on server
server.software_manager.install(DNSServer)
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
dns_server.start()
# register arcd.com as a domain
dns_server.dns_register(
domain_name="arcd.com", domain_ip_address=IPv4Address(server.nics.get(next(iter(server.nics))).ip_address)
)
return dns_client, computer, dns_server, server
def test_dns_client_server(dns_client_and_dns_server):
dns_client, computer, dns_server, server = dns_client_and_dns_server
assert dns_client.operating_state == ServiceOperatingState.RUNNING
assert dns_server.operating_state == ServiceOperatingState.RUNNING
@@ -24,3 +49,35 @@ def test_dns_client_server(uc2_network):
# arcd.com is registered in dns server and should be saved to cache
assert dns_client.check_domain_exists(target_domain="arcd.com")
assert dns_client.dns_cache.get("arcd.com", None) is not None
assert len(dns_client.dns_cache) == 1
def test_dns_client_requests_offline_dns_server(dns_client_and_dns_server):
dns_client, computer, dns_server, server = dns_client_and_dns_server
assert dns_client.operating_state == ServiceOperatingState.RUNNING
assert dns_server.operating_state == ServiceOperatingState.RUNNING
dns_server.show()
# arcd.com is registered in dns server
assert dns_client.check_domain_exists(target_domain="arcd.com")
assert dns_client.dns_cache.get("arcd.com", None) is not None
assert len(dns_client.dns_cache) == 1
dns_client.dns_cache = {}
server.power_off()
for i in range(server.shut_down_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state == NodeOperatingState.OFF
assert dns_server.operating_state == ServiceOperatingState.STOPPED
# this time it should not cache because dns server is not online
assert dns_client.check_domain_exists(target_domain="arcd.com") is False
assert dns_client.dns_cache.get("arcd.com", None) is None
assert len(dns_client.dns_cache) == 0

View File

@@ -1,4 +1,7 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
@@ -7,15 +10,28 @@ from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
def test_ftp_client_store_file_in_server(uc2_network):
@pytest.fixture(scope="function")
def ftp_client_and_ftp_server(client_server) -> Tuple[FTPClient, Computer, FTPServer, Server]:
computer, server = client_server
# Install FTP Client service on computer
computer.software_manager.install(FTPClient)
ftp_client: FTPClient = computer.software_manager.software.get("FTPClient")
ftp_client.start()
# Install FTP Server service on server
server.software_manager.install(FTPServer)
ftp_server: FTPServer = server.software_manager.software.get("FTPServer")
ftp_server.start()
return ftp_client, computer, ftp_server, server
def test_ftp_client_store_file_in_server(ftp_client_and_ftp_server):
"""
Test checks to see if the client is able to store files in the backup server.
"""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server.operating_state == ServiceOperatingState.RUNNING
@@ -28,21 +44,17 @@ def test_ftp_client_store_file_in_server(uc2_network):
src_file_name="test_file.txt",
dest_folder_name="client_1_backup",
dest_file_name="test_file.txt",
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address,
dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
)
assert ftp_server.file_system.get_file(folder_name="client_1_backup", file_name="test_file.txt")
def test_ftp_client_retrieve_file_from_server(uc2_network):
def test_ftp_client_retrieve_file_from_server(ftp_client_and_ftp_server):
"""
Test checks to see if the client is able to retrieve files from the backup server.
"""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
backup_server: Server = uc2_network.get_node_by_hostname("backup_server")
ftp_client: FTPClient = client_1.software_manager.software["FTPClient"]
ftp_server: FTPServer = backup_server.software_manager.software["FTPServer"]
ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server.operating_state == ServiceOperatingState.RUNNING
@@ -55,8 +67,41 @@ def test_ftp_client_retrieve_file_from_server(uc2_network):
src_file_name="test_file.txt",
dest_folder_name="downloads",
dest_file_name="test_file.txt",
dest_ip_address=backup_server.nics.get(next(iter(backup_server.nics))).ip_address,
dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
)
# client should have retrieved the file
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="test_file.txt")
def test_ftp_client_tries_to_connect_to_offline_server(ftp_client_and_ftp_server):
"""Test checks to make sure that the client can't do anything when the server is offline."""
ftp_client, computer, ftp_server, server = ftp_client_and_ftp_server
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server.operating_state == ServiceOperatingState.RUNNING
# create file on ftp server
ftp_server.file_system.create_file(file_name="test_file.txt", folder_name="file_share")
server.power_off()
for i in range(server.shut_down_duration + 1):
server.apply_timestep(timestep=i)
assert ftp_client.operating_state == ServiceOperatingState.RUNNING
assert ftp_server.operating_state == ServiceOperatingState.STOPPED
assert (
ftp_client.request_file(
src_folder_name="file_share",
src_file_name="test_file.txt",
dest_folder_name="downloads",
dest_file_name="test_file.txt",
dest_ip_address=server.nics.get(next(iter(server.nics))).ip_address,
)
is False
)
# client should have retrieved the file
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="test_file.txt") is None

View File

@@ -0,0 +1,129 @@
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.system.services.service import Service, ServiceOperatingState
@pytest.fixture(scope="function")
def populated_node(
service_class,
) -> Tuple[Server, Service]:
server = Server(
hostname="server", ip_address="192.168.0.1", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
)
server.software_manager.install(service_class)
service = server.software_manager.software.get("TestService")
service.start()
return server, service
def test_service_on_offline_node(service_class):
"""Test to check that the service cannot be interacted with when node it is on is off."""
computer: Computer = Computer(
hostname="test_computer",
ip_address="192.168.1.2",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
computer.software_manager.install(service_class)
service: Service = computer.software_manager.software.get("TestService")
computer.power_off()
for i in range(computer.shut_down_duration + 1):
computer.apply_timestep(timestep=i)
assert computer.operating_state is NodeOperatingState.OFF
assert service.operating_state is ServiceOperatingState.STOPPED
service.start()
assert service.operating_state is ServiceOperatingState.STOPPED
service.resume()
assert service.operating_state is ServiceOperatingState.STOPPED
service.restart()
assert service.operating_state is ServiceOperatingState.STOPPED
service.pause()
assert service.operating_state is ServiceOperatingState.STOPPED
def test_server_turns_off_service(populated_node):
"""Check that the service is turned off when the server is turned off"""
server, service = populated_node
assert server.operating_state is NodeOperatingState.ON
assert service.operating_state is ServiceOperatingState.RUNNING
server.power_off()
for i in range(server.shut_down_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state is NodeOperatingState.OFF
assert service.operating_state is ServiceOperatingState.STOPPED
def test_service_cannot_be_turned_on_when_server_is_off(populated_node):
"""Check that the service cannot be started when the server is off."""
server, service = populated_node
assert server.operating_state is NodeOperatingState.ON
assert service.operating_state is ServiceOperatingState.RUNNING
server.power_off()
for i in range(server.shut_down_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state is NodeOperatingState.OFF
assert service.operating_state is ServiceOperatingState.STOPPED
service.start()
assert server.operating_state is NodeOperatingState.OFF
assert service.operating_state is ServiceOperatingState.STOPPED
def test_server_turns_on_service(populated_node):
"""Check that turning on the server turns on service."""
server, service = populated_node
assert server.operating_state is NodeOperatingState.ON
assert service.operating_state is ServiceOperatingState.RUNNING
server.power_off()
for i in range(server.shut_down_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state is NodeOperatingState.OFF
assert service.operating_state is ServiceOperatingState.STOPPED
server.power_on()
for i in range(server.start_up_duration + 1):
server.apply_timestep(timestep=i)
assert server.operating_state is NodeOperatingState.ON
assert service.operating_state is ServiceOperatingState.RUNNING
server.start_up_duration = 0
server.shut_down_duration = 0
server.power_off()
assert server.operating_state is NodeOperatingState.OFF
assert service.operating_state is ServiceOperatingState.STOPPED
server.power_on()
assert server.operating_state is NodeOperatingState.ON
assert service.operating_state is ServiceOperatingState.RUNNING

View File

@@ -1,52 +1,118 @@
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.http import HttpStatusCode
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.service import ServiceOperatingState
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
def test_web_page_home_page(uc2_network):
"""Test to see if the browser is able to open the main page of the web server."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
@pytest.fixture(scope="function")
def web_client_and_web_server(client_server) -> Tuple[WebBrowser, Computer, WebServer, Server]:
computer, server = client_server
assert web_client.get_webpage("http://arcd.com/") is True
# Install Web Browser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.run()
# latest reponse should have status code 200
assert web_client.latest_response is not None
assert web_client.latest_response.status_code == HttpStatusCode.OK
# Install DNS Client service on computer
computer.software_manager.install(DNSClient)
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
# set dns server
dns_client.dns_server = server.nics[next(iter(server.nics))].ip_address
# Install Web Server service on server
server.software_manager.install(WebServer)
web_server_service: WebServer = server.software_manager.software.get("WebServer")
web_server_service.start()
# Install DNS Server service on server
server.software_manager.install(DNSServer)
dns_server: DNSServer = server.software_manager.software.get("DNSServer")
# register arcd.com to DNS
dns_server.dns_register(domain_name="arcd.com", domain_ip_address=server.nics[next(iter(server.nics))].ip_address)
return web_browser, computer, web_server_service, server
def test_web_page_get_users_page_request_with_domain_name(uc2_network):
def test_web_page_get_users_page_request_with_domain_name(web_client_and_web_server):
"""Test to see if the client can handle requests with domain names"""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
web_browser_app, computer, web_server_service, server = web_client_and_web_server
assert web_client.get_webpage("http://arcd.com/users/") is True
web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
# latest reponse should have status code 200
assert web_client.latest_response is not None
assert web_client.latest_response.status_code == HttpStatusCode.OK
assert web_browser_app.get_webpage() is True
# latest response should have status code 200
assert web_browser_app.latest_response is not None
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
def test_web_page_get_users_page_request_with_ip_address(uc2_network):
def test_web_page_get_users_page_request_with_ip_address(web_client_and_web_server):
"""Test to see if the client can handle requests that use ip_address."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.run()
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server: Server = uc2_network.get_node_by_hostname("web_server")
web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address
web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
web_browser_app.target_url = f"http://{web_server_ip}/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
assert web_client.get_webpage(f"http://{web_server_ip}/users/") is True
# latest response should have status code 200
assert web_browser_app.latest_response is not None
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
# latest reponse should have status code 200
assert web_client.latest_response is not None
assert web_client.latest_response.status_code == HttpStatusCode.OK
def test_web_page_request_from_shut_down_server(web_client_and_web_server):
"""Test to see that the web server does not respond when the server is off."""
web_browser_app, computer, web_server_service, server = web_client_and_web_server
web_server_ip = server.nics.get(next(iter(server.nics))).ip_address
web_browser_app.target_url = f"http://arcd.com/"
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
assert web_browser_app.get_webpage() is True
# latest response should have status code 200
assert web_browser_app.latest_response is not None
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
server.power_off()
server.power_off()
for i in range(server.shut_down_duration + 1):
server.apply_timestep(timestep=i)
# node should be off
assert server.operating_state is NodeOperatingState.OFF
assert web_browser_app.get_webpage() is False
assert web_browser_app.latest_response.status_code == HttpStatusCode.NOT_FOUND
def test_web_page_request_from_closed_web_browser(web_client_and_web_server):
web_browser_app, computer, web_server_service, server = web_client_and_web_server
assert web_browser_app.operating_state == ApplicationOperatingState.RUNNING
web_browser_app.target_url = f"http://arcd.com/"
assert web_browser_app.get_webpage() is True
# latest response should have status code 200
assert web_browser_app.latest_response.status_code == HttpStatusCode.OK
web_browser_app.close()
# node should be off
assert web_browser_app.operating_state is ApplicationOperatingState.CLOSED
assert web_browser_app.get_webpage() is False

View File

@@ -0,0 +1,108 @@
from ipaddress import IPv4Address
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.base import Link
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.applications.web_browser import WebBrowser
from primaite.simulator.system.services.database.database_service import DatabaseService
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
from primaite.simulator.system.services.web_server.web_server import WebServer
@pytest.fixture(scope="function")
def web_client_web_server_database(example_network) -> Tuple[Computer, Server, Server]:
# add rules to network router
router_1: Router = example_network.get_node_by_hostname("router_1")
router_1.acl.add_rule(
action=ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=0
)
# Allow DNS requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.DNS, dst_port=Port.DNS, position=1)
# Allow FTP requests
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.FTP, dst_port=Port.FTP, position=2)
# Open port 80 for web server
router_1.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.HTTP, dst_port=Port.HTTP, position=3)
# Create Computer
computer: Computer = example_network.get_node_by_hostname("client_1")
# Create Web Server
web_server: Server = example_network.get_node_by_hostname("server_1")
# Create Database Server
db_server = example_network.get_node_by_hostname("server_2")
# Get the NICs
computer_nic = computer.nics[next(iter(computer.nics))]
server_nic = web_server.nics[next(iter(web_server.nics))]
db_server_nic = db_server.nics[next(iter(db_server.nics))]
# Connect Computer and Server
link_computer_server = Link(endpoint_a=computer_nic, endpoint_b=server_nic)
# Should be linked
assert link_computer_server.is_up
# Connect database server and web server
link_server_db = Link(endpoint_a=server_nic, endpoint_b=db_server_nic)
# Should be linked
assert link_computer_server.is_up
assert link_server_db.is_up
# Install DatabaseService on db server
db_server.software_manager.install(DatabaseService)
db_service: DatabaseService = db_server.software_manager.software.get("DatabaseService")
db_service.start()
# Install Web Browser on computer
computer.software_manager.install(WebBrowser)
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.target_url = "http://arcd.com/users/"
web_browser.run()
# Install DNS Client service on computer
computer.software_manager.install(DNSClient)
dns_client: DNSClient = computer.software_manager.software.get("DNSClient")
# set dns server
dns_client.dns_server = web_server.nics[next(iter(web_server.nics))].ip_address
# Install Web Server service on web server
web_server.software_manager.install(WebServer)
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
web_server_service.start()
# Install DNS Server service on web server
web_server.software_manager.install(DNSServer)
dns_server: DNSServer = web_server.software_manager.software.get("DNSServer")
# register arcd.com to DNS
dns_server.dns_register(
domain_name="arcd.com", domain_ip_address=web_server.nics[next(iter(web_server.nics))].ip_address
)
# Install DatabaseClient service on web server
web_server.software_manager.install(DatabaseClient)
db_client: DatabaseClient = web_server.software_manager.software.get("DatabaseClient")
db_client.server_ip_address = IPv4Address(db_server_nic.ip_address) # set IP address of Database Server
db_client.run()
assert dns_client.check_domain_exists("arcd.com")
assert db_client.connect()
return computer, web_server, db_server
def test_web_client_requests_users(web_client_web_server_database):
computer, web_server, db_server = web_client_web_server_database
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
assert web_browser.get_webpage()

View File

@@ -1,18 +1,140 @@
"""Test the account module of the simulator."""
import pytest
from primaite.simulator.domain.account import Account, AccountType
def test_account_serialise():
@pytest.fixture(scope="function")
def account() -> Account:
acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER)
acct.set_original_state()
return acct
def test_original_state(account):
"""Test the original state - see if it resets properly"""
account.log_on()
account.log_off()
account.disable()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False
account.reset_component_for_episode(episode=1)
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_on()
account.log_off()
account.disable()
account.set_original_state()
account.log_on()
state = account.describe_state()
assert state["num_logons"] is 2
account.reset_component_for_episode(episode=2)
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False
def test_enable(account):
"""Should enable the account."""
account.enabled = False
account.enable()
assert account.enabled is True
def test_disable(account):
"""Should disable the account."""
account.enabled = True
account.disable()
assert account.enabled is False
def test_log_on_increments(account):
"""Should increase the log on value by 1."""
account.num_logons = 0
account.log_on()
assert account.num_logons is 1
def test_log_off_increments(account):
"""Should increase the log on value by 1."""
account.num_logoffs = 0
account.log_off()
assert account.num_logoffs is 1
def test_account_serialise(account):
"""Test that an account can be serialised. If pydantic throws error then this test fails."""
acct = Account(username="Jake", password="JakePass1!", account_type=AccountType.USER)
serialised = acct.model_dump_json()
serialised = account.model_dump_json()
print(serialised)
def test_account_deserialise():
def test_account_deserialise(account):
"""Test that an account can be deserialised. The test fails if pydantic throws an error."""
acct_json = (
'{"uuid":"dfb2bcaa-d3a1-48fd-af3f-c943354622b4","num_logons":0,"num_logoffs":0,"num_group_changes":0,'
'"username":"Jake","password":"JakePass1!","account_type":2,"status":2,"request_manager":null}'
'"username":"Jake","password":"totally_hashed_password","account_type":2,"status":2,"request_manager":null}'
)
acct = Account.model_validate_json(acct_json)
assert Account.model_validate_json(acct_json)
def test_describe_state(account):
state = account.describe_state()
assert state["num_logons"] is 0
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_on()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 0
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.log_off()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is True
account.disable()
state = account.describe_state()
assert state["num_logons"] is 1
assert state["num_logoffs"] is 1
assert state["num_group_changes"] is 0
assert state["username"] is "Jake"
assert state["password"] is "totally_hashed_password"
assert state["account_type"] is AccountType.USER.value
assert state["enabled"] is False

View File

@@ -185,6 +185,38 @@ def test_get_file(file_system):
file_system.show(full=True)
def test_reset_file_system(file_system):
# file and folder that existed originally
file_system.create_file(file_name="test_file.zip")
file_system.create_folder(folder_name="test_folder")
file_system.set_original_state()
# create a new file
file_system.create_file(file_name="new_file.txt")
# create a new folder
file_system.create_folder(folder_name="new_folder")
# delete the file that existed originally
file_system.delete_file(folder_name="root", file_name="test_file.zip")
assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None
# delete the folder that existed originally
file_system.delete_folder(folder_name="test_folder")
assert file_system.get_folder(folder_name="test_folder") is None
# reset
file_system.reset_component_for_episode(episode=1)
# deleted original file and folder should be back
assert file_system.get_file(folder_name="root", file_name="test_file.zip")
assert file_system.get_folder(folder_name="test_folder")
# new file and folder should be removed
assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None
assert file_system.get_folder(folder_name="new_folder") is None
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_serialisation(file_system):
"""Test to check that the object serialisation works correctly."""

View File

@@ -0,0 +1,17 @@
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.switch import Switch
@pytest.fixture(scope="function")
def switch() -> Switch:
switch: Switch = Switch(hostname="switch_1", num_ports=8, operating_state=NodeOperatingState.ON)
switch.show()
return switch
def test_describe_state(switch):
state = switch.describe_state()
assert len(state.get("ports")) is 8
assert state.get("num_ports") is 8

View File

@@ -3,6 +3,66 @@ import json
import pytest
from primaite.simulator.network.container import Network
from primaite.simulator.network.hardware.base import Link, Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.system.applications.database_client import DatabaseClient
from primaite.simulator.system.services.database.database_service import DatabaseService
@pytest.fixture(scope="function")
def network(example_network) -> Network:
assert len(example_network.routers) is 1
assert len(example_network.switches) is 2
assert len(example_network.computers) is 2
assert len(example_network.servers) is 2
example_network.set_original_state()
example_network.show()
return example_network
def test_describe_state(network):
"""Test that describe state works."""
state = network.describe_state()
assert len(state["nodes"]) is 7
assert len(state["links"]) is 6
def test_reset_network(network):
"""
Test that the network is properly reset.
TODO: make sure that once implemented - any installed/uninstalled services, processes, apps,
etc are also removed/reinstalled
"""
state_before = network.describe_state()
client_1: Computer = network.get_node_by_hostname("client_1")
server_1: Computer = network.get_node_by_hostname("server_1")
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
client_1.power_off()
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
server_1.power_off()
assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN
assert network.describe_state() != state_before
network.reset_component_for_episode(episode=1)
assert client_1.operating_state is NodeOperatingState.ON
assert server_1.operating_state is NodeOperatingState.ON
assert json.dumps(network.describe_state(), sort_keys=True, indent=2) == json.dumps(
state_before, sort_keys=True, indent=2
)
def test_creating_container():
@@ -10,11 +70,50 @@ def test_creating_container():
net = Network()
assert net.nodes == {}
assert net.links == {}
net.show()
@pytest.mark.skip(reason="Skipping until we tackle serialisation")
def test_describe_state():
"""Check that we can describe network state without raising errors, and that the result is JSON serialisable."""
net = Network()
state = net.describe_state()
json.dumps(state) # if this function call raises an error, the test fails, state was not JSON-serialisable
def test_apply_timestep_to_nodes(network):
"""Calling apply_timestep on the network should apply to the nodes within it."""
client_1: Computer = network.get_node_by_hostname("client_1")
assert client_1.operating_state is NodeOperatingState.ON
client_1.power_off()
assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN
for i in range(client_1.shut_down_duration + 1):
network.apply_timestep(timestep=i)
assert client_1.operating_state is NodeOperatingState.OFF
network.apply_timestep(client_1.shut_down_duration + 2)
assert client_1.operating_state is NodeOperatingState.OFF
def test_removing_node_that_does_not_exist(network):
"""Node that does not exist on network should not affect existing nodes."""
assert len(network.nodes) is 7
network.remove_node(Node(hostname="new_node"))
assert len(network.nodes) is 7
def test_remove_node(network):
"""Remove node should remove the correct node."""
assert len(network.nodes) is 7
client_1: Computer = network.get_node_by_hostname("client_1")
network.remove_node(client_1)
assert network.get_node_by_hostname("client_1") is None
assert len(network.nodes) is 6
def test_remove_link(network):
"""Remove link should remove the correct link."""
assert len(network.links) is 6
link: Link = network.links.get(next(iter(network.links)))
network.remove_link(link)
assert len(network.links) is 5
assert network.links.get(link.uuid) is None

View File

@@ -0,0 +1,11 @@
from primaite.simulator.network.utils import convert_bytes_to_megabits, convert_megabits_to_bytes
def test_convert_bytes_to_megabits():
assert round(convert_bytes_to_megabits(B=131072), 5) == float(1)
assert round(convert_bytes_to_megabits(B=69420), 5) == float(0.52963)
def test_convert_megabits_to_bytes():
assert round(convert_megabits_to_bytes(Mbits=1), 5) == float(131072)
assert round(convert_megabits_to_bytes(Mbits=float(0.52963)), 5) == float(69419.66336)

View File

@@ -0,0 +1,122 @@
from ipaddress import IPv4Address
from typing import Tuple, Union
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.database_client import DatabaseClient
@pytest.fixture(scope="function")
def database_client_on_computer() -> Tuple[DatabaseClient, Computer]:
computer = Computer(
hostname="db_node", ip_address="192.168.0.1", subnet_mask="255.255.255.0", operating_state=NodeOperatingState.ON
)
computer.software_manager.install(DatabaseClient)
database_client: DatabaseClient = computer.software_manager.software.get("DatabaseClient")
database_client.configure(server_ip_address=IPv4Address("192.168.0.1"))
database_client.run()
return database_client, computer
def test_creation(database_client_on_computer):
database_client, computer = database_client_on_computer
database_client.describe_state()
def test_connect_when_client_is_closed(database_client_on_computer):
"""Database client should not connect when it is not running."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
assert database_client.connect() is False
def test_connect_to_database_fails_on_reattempt(database_client_on_computer):
"""Database client should return False when the attempt to connect fails."""
database_client, computer = database_client_on_computer
database_client.connected = False
assert database_client._connect(server_ip_address=IPv4Address("192.168.0.1"), is_reattempt=True) is False
def test_disconnect_when_client_is_closed(database_client_on_computer):
"""Database client disconnect should not do anything when it is not running."""
database_client, computer = database_client_on_computer
database_client.connected = True
assert database_client.server_ip_address is not None
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
database_client.disconnect()
assert database_client.connected is True
assert database_client.server_ip_address is not None
def test_disconnect(database_client_on_computer):
"""Database client should set connected to False and remove the database server ip address."""
database_client, computer = database_client_on_computer
database_client.connected = True
assert database_client.operating_state is ApplicationOperatingState.RUNNING
assert database_client.server_ip_address is not None
database_client.disconnect()
assert database_client.connected is False
assert database_client.server_ip_address is None
def test_query_when_client_is_closed(database_client_on_computer):
"""Database client should return False when it is not running."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
assert database_client.query(sql="test") is False
def test_query_failed_reattempt(database_client_on_computer):
"""Database client query should return False if the reattempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test", is_reattempt=True) is False
def test_query_fail_to_connect(database_client_on_computer):
"""Database client query should return False if the connect attempt fails."""
database_client, computer = database_client_on_computer
def return_false():
return False
database_client.connect = return_false
database_client.connected = False
assert database_client.query(sql="test") is False
def test_client_receives_response_when_closed(database_client_on_computer):
"""Database client receive should return False when it is closed."""
database_client, computer = database_client_on_computer
database_client.close()
assert database_client.operating_state is ApplicationOperatingState.CLOSED
database_client.receive(payload={}, session_id="")

View File

@@ -1,39 +1,66 @@
from typing import Tuple
import pytest
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.protocols.http import HttpResponsePacket, HttpStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.applications.web_browser import WebBrowser
@pytest.fixture(scope="function")
def web_client() -> Computer:
node = Computer(
hostname="web_client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
def web_browser() -> WebBrowser:
computer = Computer(
hostname="web_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
return node
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
web_browser.run()
assert web_browser.operating_state is ApplicationOperatingState.RUNNING
return web_browser
def test_create_web_client(web_client):
assert web_client is not None
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
def test_create_web_client():
computer = Computer(
hostname="web_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
# Web Browser should be pre-installed in computer
web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser")
assert web_browser.name is "WebBrowser"
assert web_browser.port is Port.HTTP
assert web_browser.protocol is IPProtocol.TCP
def test_receive_invalid_payload(web_client):
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
def test_receive_invalid_payload(web_browser):
assert web_browser.receive(payload={}) is False
def test_receive_payload(web_client):
def test_receive_payload(web_browser):
payload = HttpResponsePacket(status_code=HttpStatusCode.OK)
web_browser: WebBrowser = web_client.software_manager.software["WebBrowser"]
assert web_browser.latest_response is None
web_browser.receive(payload=payload)
assert web_browser.latest_response is not None
def test_invalid_target_url(web_browser):
# none value target url
web_browser.target_url = None
assert web_browser.get_webpage() is False
def test_non_existent_target_url(web_browser):
web_browser.target_url = "http://192.168.255.255"
assert web_browser.get_webpage() is False

View File

@@ -1,20 +1,73 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.networks import arcd_uc2_network
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.red_services.data_manipulation_bot import DataManipulationBot
from primaite.simulator.system.applications.application import ApplicationOperatingState
from primaite.simulator.system.services.red_services.data_manipulation_bot import (
DataManipulationAttackStage,
DataManipulationBot,
)
def test_creation():
@pytest.fixture(scope="function")
def dm_client() -> Node:
network = arcd_uc2_network()
return network.get_node_by_hostname("client_1")
client_1: Node = network.get_node_by_hostname("client_1")
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
@pytest.fixture
def dm_bot(dm_client) -> DataManipulationBot:
return dm_client.software_manager.software.get("DataManipulationBot")
def test_create_dm_bot(dm_client):
data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot")
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
assert data_manipulation_bot.protocol == IPProtocol.TCP
assert data_manipulation_bot.payload == "DELETE"
def test_dm_bot_logon(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.NOT_STARTED
dm_bot._logon()
assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON
def test_dm_bot_perform_port_scan_no_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.LOGON
dm_bot._perform_port_scan(p_of_success=0.0)
assert dm_bot.attack_stage == DataManipulationAttackStage.LOGON
def test_dm_bot_perform_port_scan_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.LOGON
dm_bot._perform_port_scan(p_of_success=1.0)
assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN
def test_dm_bot_perform_data_manipulation_no_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN
dm_bot._perform_data_manipulation(p_of_success=0.0)
assert dm_bot.attack_stage == DataManipulationAttackStage.PORT_SCAN
def test_dm_bot_perform_data_manipulation_success(dm_bot):
dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN
dm_bot.operating_state = ApplicationOperatingState.RUNNING
dm_bot._perform_data_manipulation(p_of_success=1.0)
assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED)
assert dm_bot.connected

View File

@@ -8,7 +8,7 @@ from primaite.simulator.system.services.database.database_service import Databas
def database_server() -> Node:
node = Node(hostname="db_node")
node.software_manager.install(DatabaseService)
node.software_manager.software["DatabaseService"].start()
node.software_manager.software.get("DatabaseService").start()
return node

View File

@@ -0,0 +1,104 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture(scope="function")
def dns_client() -> Node:
node = Computer(
hostname="dns_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
)
return node
def test_create_dns_client(dns_client):
assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client_service.name is "DNSClient"
assert dns_client_service.port is Port.DNS
assert dns_client_service.protocol is IPProtocol.TCP
def test_dns_client_add_domain_to_cache_when_not_running(dns_client):
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client.operating_state is NodeOperatingState.OFF
assert dns_client_service.operating_state is ServiceOperatingState.STOPPED
assert (
dns_client_service.add_domain_to_cache(domain_name="test.com", ip_address=IPv4Address("192.168.1.100")) is False
)
assert dns_client_service.dns_cache.get("test.com") is None
def test_dns_client_check_domain_exists_when_not_running(dns_client):
dns_client.operating_state = NodeOperatingState.ON
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.start()
assert dns_client.operating_state is NodeOperatingState.ON
assert dns_client_service.operating_state is ServiceOperatingState.RUNNING
assert (
dns_client_service.add_domain_to_cache(domain_name="test.com", ip_address=IPv4Address("192.168.1.100"))
is not False
)
assert dns_client_service.check_domain_exists("test.com") is True
dns_client.power_off()
for i in range(dns_client.shut_down_duration + 1):
dns_client.apply_timestep(timestep=i)
assert dns_client.operating_state is NodeOperatingState.OFF
assert dns_client_service.operating_state is ServiceOperatingState.STOPPED
assert dns_client_service.check_domain_exists("test.com") is False
def test_dns_client_check_domain_in_cache(dns_client):
"""Test to make sure that the check_domain_in_cache returns the correct values."""
dns_client.operating_state = NodeOperatingState.ON
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.start()
# add a domain to the dns client cache
dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12"))
assert dns_client_service.check_domain_exists("fake-domain.com") is False
assert dns_client_service.check_domain_exists("real-domain.com") is True
def test_dns_client_receive(dns_client):
"""Test to make sure the DNS Client knows how to deal with request responses."""
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
dns_client_service.receive(
payload=DNSPacket(
dns_request=DNSRequest(domain_name_request="real-domain.com"),
dns_reply=DNSReply(domain_name_ip_address=IPv4Address("192.168.1.12")),
)
)
# domain name should be saved to cache
assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")
def test_dns_client_receive_non_dns_payload(dns_client):
dns_client_service: DNSClient = dns_client.software_manager.software.get("DNSClient")
assert dns_client_service.receive(payload=None) is False

View File

@@ -3,56 +3,38 @@ from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.dns import DNSPacket, DNSReply, DNSRequest
from primaite.simulator.network.protocols.dns import DNSPacket, DNSRequest
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.dns.dns_client import DNSClient
from primaite.simulator.system.services.dns.dns_server import DNSServer
@pytest.fixture(scope="function")
def dns_server() -> Node:
node = Server(
hostname="dns_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
)
node.software_manager.install(software_class=DNSServer)
node.software_manager.software["DNSServer"].start()
return node
@pytest.fixture(scope="function")
def dns_client() -> Node:
node = Computer(
hostname="dns_client",
ip_address="192.168.1.11",
hostname="dns_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
dns_server=IPv4Address("192.168.1.10"),
operating_state=NodeOperatingState.ON,
)
node.software_manager.install(software_class=DNSServer)
return node
def test_create_dns_server(dns_server):
assert dns_server is not None
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
assert dns_server_service.name is "DNSServer"
assert dns_server_service.port is Port.DNS
assert dns_server_service.protocol is IPProtocol.TCP
def test_create_dns_client(dns_client):
assert dns_client is not None
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
assert dns_client_service.name is "DNSClient"
assert dns_client_service.port is Port.DNS
assert dns_client_service.protocol is IPProtocol.TCP
def test_dns_server_domain_name_registration(dns_server):
"""Test to check if the domain name registration works."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
@@ -62,20 +44,9 @@ def test_dns_server_domain_name_registration(dns_server):
assert dns_server_service.dns_lookup("real-domain.com") is not None
def test_dns_client_check_domain_in_cache(dns_client):
"""Test to make sure that the check_domain_in_cache returns the correct values."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
# add a domain to the dns client cache
dns_client_service.add_domain_to_cache("real-domain.com", IPv4Address("192.168.1.12"))
assert dns_client_service.check_domain_exists("fake-domain.com") is False
assert dns_client_service.check_domain_exists("real-domain.com") is True
def test_dns_server_receive(dns_server):
"""Test to make sure that the DNS Server correctly responds to a DNS Client request."""
dns_server_service: DNSServer = dns_server.software_manager.software["DNSServer"]
dns_server_service: DNSServer = dns_server.software_manager.software.get("DNSServer")
# register the web server in the domain controller
dns_server_service.dns_register(domain_name="real-domain.com", domain_ip_address=IPv4Address("192.168.1.12"))
@@ -91,18 +62,3 @@ def test_dns_server_receive(dns_server):
)
dns_server_service.show()
def test_dns_client_receive(dns_client):
"""Test to make sure the DNS Client knows how to deal with request responses."""
dns_client_service: DNSClient = dns_client.software_manager.software["DNSClient"]
dns_client_service.receive(
payload=DNSPacket(
dns_request=DNSRequest(domain_name_request="real-domain.com"),
dns_reply=DNSReply(domain_name_ip_address=IPv4Address("192.168.1.12")),
)
)
# domain name should be saved to cache
assert dns_client_service.dns_cache["real-domain.com"] == IPv4Address("192.168.1.12")

View File

@@ -0,0 +1,122 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture(scope="function")
def ftp_client() -> Node:
node = Computer(
hostname="ftp_client",
ip_address="192.168.1.11",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
return node
def test_create_ftp_client(ftp_client):
assert ftp_client is not None
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.name is "FTPClient"
assert ftp_client_service.port is Port.FTP
assert ftp_client_service.protocol is IPProtocol.TCP
def test_ftp_client_store_file(ftp_client):
"""Test to make sure the FTP Client knows how to deal with request responses."""
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
response: FTPPacket = FTPPacket(
ftp_command=FTPCommand.STOR,
ftp_command_args={
"dest_folder_name": "downloads",
"dest_file_name": "file.txt",
"file_size": 24,
},
packet_payload_size=24,
status_code=FTPStatusCode.OK,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client_service.receive(response)
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt")
def test_ftp_should_not_process_commands_if_service_not_running(ftp_client):
"""Method _process_ftp_command should return false if service is not running."""
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
status_code=FTPStatusCode.OK,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client_service.stop()
assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED
assert ftp_client_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
def test_ftp_tries_to_senf_file__that_does_not_exist(ftp_client):
"""Method send_file should return false if no file to send."""
assert ftp_client.file_system.get_file(folder_name="root", file_name="test.txt") is None
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.operating_state is ServiceOperatingState.RUNNING
assert (
ftp_client_service.send_file(
dest_ip_address=IPv4Address("192.168.1.1"),
src_folder_name="root",
src_file_name="test.txt",
dest_folder_name="root",
dest_file_name="text.txt",
)
is False
)
def test_offline_ftp_client_receives_request(ftp_client):
"""Receive should return false if the node the ftp client is installed on is offline."""
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
ftp_client.power_off()
for i in range(ftp_client.shut_down_duration + 1):
ftp_client.apply_timestep(timestep=i)
assert ftp_client.operating_state is NodeOperatingState.OFF
assert ftp_client_service.operating_state is ServiceOperatingState.STOPPED
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
status_code=FTPStatusCode.OK,
)
assert ftp_client_service.receive(payload=payload) is False
def test_receive_should_fail_if_payload_is_not_ftp(ftp_client):
"""Receive should return false if the node the ftp client is installed on is not an FTPPacket."""
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.receive(payload=None) is False
def test_receive_should_ignore_payload_with_none_status_code(ftp_client):
"""Receive should ignore payload with no set status code to prevent infinite send/receive loops."""
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=Port.FTP,
status_code=None,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software.get("FTPClient")
assert ftp_client_service.receive(payload=payload) is False

View File

@@ -1,51 +1,36 @@
from ipaddress import IPv4Address
import pytest
from primaite.simulator.network.hardware.base import Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState
from primaite.simulator.network.hardware.nodes.server import Server
from primaite.simulator.network.protocols.ftp import FTPCommand, FTPPacket, FTPStatusCode
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.services.ftp.ftp_client import FTPClient
from primaite.simulator.system.services.ftp.ftp_server import FTPServer
from primaite.simulator.system.services.service import ServiceOperatingState
@pytest.fixture(scope="function")
def ftp_server() -> Node:
node = Server(
hostname="ftp_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
hostname="ftp_server",
ip_address="192.168.1.10",
subnet_mask="255.255.255.0",
default_gateway="192.168.1.1",
operating_state=NodeOperatingState.ON,
)
node.software_manager.install(software_class=FTPServer)
node.software_manager.software["FTPServer"].start()
return node
@pytest.fixture(scope="function")
def ftp_client() -> Node:
node = Computer(
hostname="ftp_client", ip_address="192.168.1.11", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
)
return node
def test_create_ftp_server(ftp_server):
assert ftp_server is not None
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service.name is "FTPServer"
assert ftp_server_service.port is Port.FTP
assert ftp_server_service.protocol is IPProtocol.TCP
def test_create_ftp_client(ftp_client):
assert ftp_client is not None
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
assert ftp_client_service.name is "FTPClient"
assert ftp_client_service.port is Port.FTP
assert ftp_client_service.protocol is IPProtocol.TCP
def test_ftp_server_store_file(ftp_server):
"""Test to make sure the FTP Server knows how to deal with request responses."""
assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
@@ -60,16 +45,34 @@ def test_ftp_server_store_file(ftp_server):
packet_payload_size=24,
)
ftp_server_service: FTPServer = ftp_server.software_manager.software["FTPServer"]
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
ftp_server_service.receive(response)
assert ftp_server.file_system.get_file(folder_name="downloads", file_name="file.txt")
def test_ftp_client_store_file(ftp_client):
"""Test to make sure the FTP Client knows how to deal with request responses."""
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt") is None
def test_ftp_server_should_send_error_if_port_arg_is_invalid(ftp_server):
"""Should fail if the port command receives an invalid port."""
payload: FTPPacket = FTPPacket(
ftp_command=FTPCommand.PORT,
ftp_command_args=None,
packet_payload_size=24,
)
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service._process_ftp_command(payload=payload).status_code is FTPStatusCode.ERROR
def test_ftp_server_receives_non_ftp_packet(ftp_server):
"""Receive should return false if the service receives a non ftp packet."""
response: FTPPacket = None
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
assert ftp_server_service.receive(response) is False
def test_offline_ftp_server_receives_request(ftp_server):
"""Receive should return false if the service is stopped."""
response: FTPPacket = FTPPacket(
ftp_command=FTPCommand.STOR,
ftp_command_args={
@@ -78,10 +81,9 @@ def test_ftp_client_store_file(ftp_client):
"file_size": 24,
},
packet_payload_size=24,
status_code=FTPStatusCode.OK,
)
ftp_client_service: FTPClient = ftp_client.software_manager.software["FTPClient"]
ftp_client_service.receive(response)
assert ftp_client.file_system.get_file(folder_name="downloads", file_name="file.txt")
ftp_server_service: FTPServer = ftp_server.software_manager.software.get("FTPServer")
ftp_server_service.stop()
assert ftp_server_service.operating_state is ServiceOperatingState.STOPPED
assert ftp_server_service.receive(response) is False

View File

@@ -18,13 +18,13 @@ def web_server() -> Server:
hostname="web_server", ip_address="192.168.1.10", subnet_mask="255.255.255.0", default_gateway="192.168.1.1"
)
node.software_manager.install(software_class=WebServer)
node.software_manager.software["WebServer"].start()
node.software_manager.software.get("WebServer").start()
return node
def test_create_web_server(web_server):
assert web_server is not None
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service.name is "WebServer"
assert web_server_service.port is Port.HTTP
assert web_server_service.protocol is IPProtocol.TCP
@@ -33,7 +33,7 @@ def test_create_web_server(web_server):
def test_handling_get_request_not_found_path(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/fake-path")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
assert response.status_code == HttpStatusCode.NOT_FOUND
@@ -42,7 +42,7 @@ def test_handling_get_request_not_found_path(web_server):
def test_handling_get_request_home_page(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
response: HttpResponsePacket = web_server_service._handle_get_request(payload=payload)
assert response.status_code == HttpStatusCode.OK
@@ -51,7 +51,7 @@ def test_handling_get_request_home_page(web_server):
def test_process_http_request_get(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.GET, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service._process_http_request(payload=payload) is True
@@ -59,6 +59,6 @@ def test_process_http_request_get(web_server):
def test_process_http_request_method_not_allowed(web_server):
payload = HttpRequestPacket(request_method=HttpRequestMethod.DELETE, request_url="http://domain.com/")
web_server_service: WebServer = web_server.software_manager.software["WebServer"]
web_server_service: WebServer = web_server.software_manager.software.get("WebServer")
assert web_server_service._process_http_request(payload=payload) is False

Some files were not shown because too many files have changed in this diff Show More