diff --git a/CHANGELOG.md b/CHANGELOG.md index ce366d26..c01f0139 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,27 +6,35 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +- Made requests fail to reach their target if the node is off +- Added responses to requests +- Made environment reset completely recreate the game object. +- Changed the red agent in the data manipulation scenario to randomly choose client 1 or client 2 to start its attack. - Changed the data manipulation scenario to include a second green agent on client 1. - Refactored actions and observations to be configurable via object name, instead of UUID. -- Fixed a bug where ACL rules were not resetting on episode reset. -- Fixed a bug where blue agent's ACL actions were being applied against the wrong IP addresses -- Fixed a bug where deleted files and folders did not reset correctly on episode reset. -- Fixed a bug where service health status was using the actual health state instead of the visible health state -- Fixed a bug where the database file health status was using the incorrect value for negative rewards -- Fixed a bug preventing file actions from reaching their intended file - Made database patch correctly take 2 timesteps instead of being immediate - Made database patch only possible when the software is compromised or good, it's no longer possible when the software is OFF or RESETTING -- Temporarily disable the blue agent file delete action due to crashes. This issue is resolved in another branch that will be merged into dev soon. -- Fix a bug where ACLs were not showing up correctly in the observation space. - Added a notebook which explains Data manipulation scenario, demonstrates the attack, and shows off blue agent's action space, observation space, and reward function. - Made packet capture and system logging optional (off by default). To turn on, change the io_settings.save_pcap_logs and io_settings.save_sys_logs settings in the config. -- Made observation space flattening optional (on by default). To turn off for an agent, change the agent_settings.flatten_obs setting in the config. -- Fixed an issue where the data manipulation attack was triggered at episode start. -- Fixed a bug where FTP STOR stored an additional copy on the client machine's filesystem -- Fixed a bug where the red agent acted to early -- Fixed the order of service health state -- Fixed an issue where starting a node didn't start the services on it +- Made observation space flattening optional (on by default). To turn off for an agent, change the `agent_settings.flatten_obs` setting in the config. - Added support for SQL INSERT command. +- Added ability to log each agent's action choices in each step to a JSON file. + +### Bug Fixes + +- ACL rules were not resetting on episode reset. +- ACLs were not showing up correctly in the observation space. +- Blue agent's ACL actions were being applied against the wrong IP addresses +- Deleted files and folders did not reset correctly on episode reset. +- Service health status was using the actual health state instead of the visible health state +- Database file health status was using the incorrect value for negative rewards +- Preventing file actions from reaching their intended file +- The data manipulation attack was triggered at episode start. +- FTP STOR stored an additional copy on the client machine's filesystem +- The red agent acted to early +- Order of service health state +- Starting a node didn't start the services on it +- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off @@ -46,8 +54,12 @@ a Service/Application another machine. SessionManager. - Permission System - each action can define criteria that will be used to permit or deny agent actions. - File System - ability to emulate a node's file system during a simulation -- Example notebooks - There is currently 1 jupyter notebook which walks through using PrimAITE - 1. Creating a simulation - this notebook explains how to build up a simulation using the Python package. (WIP) +- Example notebooks - There are 5 jupyter notebook which walk through using PrimAITE + 1. Training a Stable Baselines 3 agent + 2. Training a single agent system using Ray RLLib + 3. Training a multi-agent system Ray RLLib + 4. Data manipulation end to end demonstration + 5. Data manipulation scenario with customised red agents - Database: - `DatabaseClient` and `DatabaseService` created to allow emulation of database actions - Ability for `DatabaseService` to backup its data to another server via FTP and restore data from backup @@ -57,7 +69,6 @@ SessionManager. - DNS Services: `DNSClient` and `DNSServer` - FTP Services: `FTPClient` and `FTPServer` - HTTP Services: `WebBrowser` to simulate a web client and `WebServer` -- Fixed an issue where the services were still able to run even though the node the service is installed on is turned off - NTP Services: `NTPClient` and `NTPServer` - **RouterNIC Class**: Introduced a new class `RouterNIC`, extending the standard `NIC` functionality. This class is specifically designed for router operations, optimizing the processing and routing of network traffic. - **Custom Layer-3 Processing**: The `RouterNIC` class includes custom handling for network frames, bypassing standard Node NIC's Layer 3 broadcast/unicast checks. This allows for more efficient routing behavior in network scenarios where router-specific frame processing is required. @@ -81,7 +92,21 @@ SessionManager. - `AirSpace` class to simulate wireless communications, managing wireless interfaces and facilitating the transmission of frames within specified frequencies. - `AirSpaceFrequency` enum for defining standard wireless frequencies, including 2.4 GHz and 5 GHz bands, to support realistic wireless network simulations. - `WirelessRouter` class, extending the `Router` class, to incorporate wireless networking capabilities alongside traditional wired connections. This class allows the configuration of wireless access points with specific IP settings and operating frequencies. - +- Documentation Updates: + - Examples include how to set up PrimAITE session via config + - Examples include how to create nodes and install software via config + - Examples include how to set up PrimAITE session via Python + - Examples include how to create nodes and install software via Python + - Added missing ``DoSBot`` documentation page + - Added diagrams where needed to make understanding some things easier + - Templated parts of the documentation to prevent unnecessary repetition and for easier maintaining of documentation + - Separated documentation pages of some items i.e. client and server software were on the same pages - which may make things confusing + - Configuration section at the bottom of the software pages specifying the configuration options available (and which ones are optional) +- Ability to add ``Firewall`` node via config +- Ability to add ``Router`` routes via config +- Ability to add ``Router``/``Firewall`` ``ACLRule`` via config +- NMNE capturing capabilities to `NetworkInterface` class for detecting and logging Malicious Network Events. +- New `nmne_config` settings in the simulation configuration to enable NMNE capturing and specify keywords such as "DELETE". ### Changed - Integrated the RouteTable into the Routers frame processing. @@ -93,7 +118,9 @@ SessionManager. - Refactored all tests to utilise new `Node` subclasses (`Computer`, `Server`, `Router`, `Switch`) instead of creating generic `Node` instances and manually adding network interfaces. This change aligns test setups more closely with the intended use cases and hierarchies within the network simulation framework. - Updated all tests to employ the `Network()` class for managing nodes and their connections, ensuring a consistent and structured approach to setting up network topologies in testing scenarios. - **ACLRule Wildcard Masking**: Updated the `ACLRule` class to support IP ranges using wildcard masking. This enhancement allows for more flexible and granular control over traffic filtering, enabling the specification of broader or more specific IP address ranges in ACL rules. - +- Updated `NetworkInterface` documentation to reflect the new NMNE capturing features and how to use them. +- Integration of NMNE capturing functionality within the `NicObservation` class. +- Changed blue action set to enable applying node scan, reset, start, and shutdown to every host in data manipulation scenario ### Removed - Removed legacy simulation modules: `acl`, `common`, `environment`, `links`, `nodes`, `pol` @@ -103,7 +130,7 @@ SessionManager. ### Fixed - Addressed network transmission issues that previously allowed ARP requests to be incorrectly routed and repeated across different subnets. This fix ensures ARP requests are correctly managed and confined to their appropriate network segments. - Resolved problems in `Node` and its subclasses where the default gateway configuration was not properly utilized for communications across different subnets. This correction ensures that nodes effectively use their configured default gateways for outbound communications to other network segments, thereby enhancing the network's routing functionality and reliability. - +- Network Interface Port name/num being set properly for sys log and PCAP output. ## [2.0.0] - 2023-07-26 diff --git a/docs/Makefile b/docs/Makefile index dd71ec33..82719283 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -6,7 +6,7 @@ SPHINXBUILD ?= sphinx-build SOURCEDIR = . BUILDDIR = _build -AUTOSUMMARY="source\_autosummary" +AUTOSUMMARY="source/_autosummary" # Remove command is different depending on OS ifdef OS diff --git a/docs/_static/firewall_acl.png b/docs/_static/firewall_acl.png new file mode 100644 index 00000000..1cdd2526 Binary files /dev/null and b/docs/_static/firewall_acl.png differ diff --git a/docs/_static/notebooks/extensions.png b/docs/_static/notebooks/extensions.png new file mode 100644 index 00000000..8441802d Binary files /dev/null and b/docs/_static/notebooks/extensions.png differ diff --git a/docs/_static/notebooks/install_extensions.png b/docs/_static/notebooks/install_extensions.png new file mode 100644 index 00000000..db026ce3 Binary files /dev/null and b/docs/_static/notebooks/install_extensions.png differ diff --git a/docs/_static/switched_p2p_network.png b/docs/_static/switched_p2p_network.png new file mode 100644 index 00000000..d1769942 Binary files /dev/null and b/docs/_static/switched_p2p_network.png differ diff --git a/docs/api.rst b/docs/api.rst index aeaef4e2..13f3a1ec 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,3 +1,5 @@ +:orphan: + .. only:: comment © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK diff --git a/docs/conf.py b/docs/conf.py index efd60b49..a666e460 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -10,12 +10,12 @@ import datetime # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information import os import sys +from typing import Any import furo # noqa sys.path.insert(0, os.path.abspath("../")) - # -- Project information ----------------------------------------------------- year = datetime.datetime.now().year project = "PrimAITE" @@ -28,6 +28,11 @@ with open("../src/primaite/VERSION", "r") as file: # The full version, including alpha/beta/rc tags release = version +# set global variables +rst_prolog = f""" +.. |VERSION| replace:: {release} +""" + html_title = f"{project} v{release} docs" # -- General configuration --------------------------------------------------- @@ -45,13 +50,35 @@ extensions = [ "sphinx_copybutton", # Adds a copy button to code blocks ] - templates_path = ["_templates"] -exclude_patterns = ["_build", "Thumbs.db", ".DS_Store"] - +exclude_patterns = [ + "_build", + "Thumbs.db", + ".DS_Store", +] # -- Options for HTML output ------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output html_theme = "furo" html_static_path = ["_static"] +html_theme_options = {"globaltoc_collapse": True, "globaltoc_maxdepth": 2} +html_copy_source = False + + +def replace_token(app: Any, docname: Any, source: Any): + """Replaces a token from the list of tokens.""" + result = source[0] + for key in app.config.tokens: + result = result.replace(key, app.config.tokens[key]) + source[0] = result + + +tokens = {"{VERSION}": release} # Token VERSION is replaced by the value of the PrimAITE version in the version file +"""Dict containing the tokens that need to be replaced in documentation.""" + + +def setup(app: Any): + """Custom setup for sphinx.""" + app.add_config_value("tokens", {}, True) + app.connect("source-read", replace_token) diff --git a/docs/index.rst b/docs/index.rst index 9eae8adc..4cc81b13 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -105,9 +105,12 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE! source/getting_started source/primaite_session + source/example_notebooks source/simulation source/game_layer source/config + source/environment + source/customising_scenarios .. toctree:: :caption: Developer information: diff --git a/docs/source/config.rst b/docs/source/config.rst index 575a3139..89181a24 100644 --- a/docs/source/config.rst +++ b/docs/source/config.rst @@ -1,102 +1,40 @@ -Primaite v3 config -****************** +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +PrimAITE |VERSION| Configuration +******************************** PrimAITE uses a single configuration file to define everything needed to train and evaluate an RL policy in a custom cybersecurity scenario. This includes the configuration of the network, the scripted or trained agents that interact with the network, as well as settings that define how to perform training in Stable Baselines 3 or Ray RLLib. The entire config is used by the ``PrimaiteSession`` object for users who wish to let PrimAITE handle the agent definition and training. If you wish to define custom agents and control the training loop yourself, you can use the config with the ``PrimaiteGame``, and ``PrimaiteGymEnv`` objects instead. That way, only the network configuration and agent setup parts of the config are used, and the training section is ignored. +Example Configuration Hierarchy +############################### +The top level configuration items in a configuration file is as follows + +.. code-block:: yaml + + training_config: + ... + io_settings: + ... + game: + ... + agents: + ... + simulation: + ... + +These are expanded upon in the Configurable items section below + Configurable items -================== +################## -``training_config`` -------------------- -This section allows selecting which training framework and algorithm to use, and set some training hyperparameters. +.. toctree:: + :maxdepth: 1 -``io_settings`` ---------------- -This section configures how PrimAITE saves data during simulation and training. - -**save_final_model**: Only used if training with PrimaiteSession, if true, the policy will be saved after the final training iteration. - -**save_checkpoints**: Only used if training with PrimaiteSession, if true, the policy will be saved periodically during training. - -**checkpoint_interval**: Only used if training with PrimaiteSession and if ``save_checkpoints`` is true. Defines how often to save the policy during training. - -**save_logs**: *currently unused*. - -**save_transactions**: *currently unused*. - -**save_tensorboard_logs**: *currently unused*. - -**save_step_metadata**: Whether to save the RL agents' action, environment state, and other data at every single step. - -**save_pcap_logs**: Whether to save pcap files of all network traffic during the simulation. - -**save_sys_logs**: Whether to save system logs from all nodes during the simulation. - -``game`` --------- -This section defines high-level settings that apply across the game, currently it's used to help shape the action and observation spaces by restricting which ports and internet protocols should be considered. Here, users can also set the maximum number of steps in an episode. - -``agents`` ----------- -Agents can be scripted (deterministic and stochastic), or controlled by a reinforcement learning algorithm. Not to be confused with an RL agent, the term agent here is used to refer to an entity that sends requests to the simulated network. In this part of the config, each agent's action space, observation space, and reward function can be defined. All three are defined in a modular way. - -**type**: Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``GreenWebBrowsingAgent`` generate their own behaviour. - -**team**: Specifies if the agent is malicious (RED), benign (GREEN), or defensive (BLUE). Currently this value is not used for anything. - -**observation space:** - * ``type``: selects which python class from the ``primaite.game.agent.observation`` module is used for the overall observation structure. - * ``options``: allows configuring the chosen observation type. The ``UC2BlueObservation`` should be used for RL Agents. - * ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node`` all define the shape of the observation space. The size and shape of the obs space must remain constant, but the number of files, folders, ACL rules, and other components can change within an episode. Therefore padding is performed and these options set the size of the obs space. - * ``nodes``: list of nodes that will be present in this agent's observation space. The ``node_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. Each node can also be configured with services, and files that should be monitored. - * ``links``: list of links that will be present in this agent's observation space. The ``link_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. - * ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_address_order`` sets the encoding of ip addresses as integers within the observation space. - -**action space:** -The action space is configured to be made up of individual action types. Once configured, the agent can select an action type and some optional action parameters at every step. For example: The ``NODE_SERVICE_SCAN`` action takes the parameters ``node_id`` and ``service_id``. - -Description of configurable items: - * ``action_list``: a list of action modules. The options are listed in the ``primaite.game.agent.actions`` module. - * ``action_map``: (optional). Restricts the possible combinations of action type / action parameter values to reduce the overall size of the action space. By default, every possible combination of actions and parameters will be assigned an integer for the agent's ``MultiDiscrete`` action space. Instead, the ``action_map`` allows you to list the actions corresponding to each integer in the ``MultiDiscrete`` space. - * ``options``: Options that apply too all action components. - * ``nodes``: list the nodes that the agent can act on, the order of this list defines the mapping between nodes and ``node_id`` integers. - * ``max_folders_per_node``, ``max_files_per_folder``, ``max_services_per_node``, ``max_nics_per_node``, ``max_acl_rules`` all are used to define the size of the action space. - -**reward function:** -Similar to action space, this is defined as a list of components. - -Description of configurable items: - * ``reward_components`` a list of reward components from the ``primaite.game.agent.reward`` module. - * ``weight``: relative importance of this reward component. The total reward for a step is a weighted sum of all reward components. - * ``options``: list of options passed to the reward component during initialisation, the exact options required depend on the reward component. - -**agent_settings**: -Settings passed to the agent during initialisation. These depend on the agent class. - -Reinforcement learning agents use the ``ProxyAgent`` class, they accept these agent settings: - -**flatten_obs**: If true, gymnasium flattening will be performed on the observation space before sending to the agent. Set this to true if your agent does not support nested observation spaces. - -``simulation`` --------------- -In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents. - -At the top level of the network are ``nodes`` and ``links``. - -**nodes:** - * ``type``: one of ``router``, ``switch``, ``computer``, or ``server``, this affects what other sub-options should be defined. - * ``hostname`` - a non-unique name used for logging and outputs. - * ``num_ports`` (optional, routers and switches only): number of network interfaces present on the device. - * ``ports`` (optional, routers and switches only): configuration for each network interface, including IP address and subnet mask. - * ``acl`` (Router only): Define the ACL rules at each index of the ACL on the router. the possible options are: ``action`` (PERMIT or DENY), ``src_port``, ``dst_port``, ``protocol``, ``src_ip``, ``dst_ip``. Any options left blank default to none which usually means that it will apply across all options. For example leaving ``src_ip`` blank will apply the rule to all IP addresses. - * ``services`` (computers and servers only): a list of services to install on the node. They must define a ``ref``, ``type``, and ``options`` that depend on which ``type`` was selected. - * ``applications`` (computer and servers only): Similar to services. A list of application to install on the node. - * ``network_interfaces`` (computers and servers only): If the node has multiple networking devices, the second, third, fourth, etc... must be defined here with an ``ip_address`` and ``subnet_mask``. - -**links:** - * ``ref``: unique identifier for this link - * ``endpoint_a_ref``: Reference to the node at the first end of the link - * ``endpoint_a_port``: The ethernet port or switch port index of the second node - * ``endpoint_b_ref``: Reference to the node at the second end of the link - * ``endpoint_b_port``: The ethernet port or switch port index on the second node + configuration/training_config.rst + configuration/io_settings.rst + configuration/game.rst + configuration/agents.rst + configuration/simulation.rst diff --git a/docs/source/configuration/agents.rst b/docs/source/configuration/agents.rst new file mode 100644 index 00000000..b8912883 --- /dev/null +++ b/docs/source/configuration/agents.rst @@ -0,0 +1,174 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + + +``agents`` +========== +Agents can be scripted (deterministic and stochastic), or controlled by a reinforcement learning algorithm. Not to be confused with an RL agent, the term agent here is used to refer to an entity that sends requests to the simulated network. In this part of the config, each agent's action space, observation space, and reward function can be defined. All three are defined in a modular way. + +``agents`` hierarchy +-------------------- + +.. code-block:: yaml + + agents: + - ref: red_agent_example + ... + - ref: blue_agent_example + ... + - ref: green_agent_example + team: GREEN + type: ProbabilisticAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + flatten_obs: False + +``ref`` +------- +The reference to be used for the given agent. + +``team`` +-------- +Specifies if the agent is malicious (``RED``), benign (``GREEN``), or defensive (``BLUE``). Currently this value is not used for anything other than for human readability in the configuration file. + +``type`` +-------- +Specifies which class should be used for the agent. ``ProxyAgent`` is used for agents that receive instructions from an RL algorithm. Scripted agents like ``RedDatabaseCorruptingAgent`` and ``ProbabilisticAgent`` generate their own behaviour. + +Available agent types: + +- ``ProbabilisticAgent`` +- ``ProxyAgent`` +- ``RedDatabaseCorruptingAgent`` + +``observation_space`` +--------------------- +Defines the observation space of the agent. + +``type`` +^^^^^^^^ + +selects which python class from the :py:mod:`primaite.game.agent.observation` module is used for the overall observation structure. + +``options`` +^^^^^^^^^^^ + +Allows configuration of the chosen observation type. These are optional. + + * ``num_services_per_node``, ``num_folders_per_node``, ``num_files_per_folder``, ``num_nics_per_node`` all define the shape of the observation space. The size and shape of the obs space must remain constant, but the number of files, folders, ACL rules, and other components can change within an episode. Therefore padding is performed and these options set the size of the obs space. + * ``nodes``: list of nodes that will be present in this agent's observation space. The ``node_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. Each node can also be configured with services, and files that should be monitored. + * ``links``: list of links that will be present in this agent's observation space. The ``link_ref`` relates to the human-readable unique reference defined later in the ``simulation`` part of the config. + * ``acl``: configure how the agent reads the access control list on the router in the simulation. ``router_node_ref`` is for selecting which router's ACL table should be used. ``ip_address_order`` sets the encoding of ip addresses as integers within the observation space. + +For more information see :py:mod:`primaite.game.agent.observations` + +``action_space`` +---------------- + +The action space is configured to be made up of individual action types. Once configured, the agent can select an action type and some optional action parameters at every step. For example: The ``NODE_SERVICE_SCAN`` action takes the parameters ``node_id`` and ``service_id``. + +``action_list`` +^^^^^^^^^^^^^^^ + +A list of action modules. The options are listed in the :py:mod:`primaite.game.agent.actions.ActionManager.act_class_identifiers` module. + +``action_map`` +^^^^^^^^^^^^^^ + +Restricts the possible combinations of action type / action parameter values to reduce the overall size of the action space. By default, every possible combination of actions and parameters will be assigned an integer for the agent's ``MultiDiscrete`` action space. Instead, the ``action_map`` allows you to list the actions corresponding to each integer in the ``MultiDiscrete`` space. + +This is Optional. + +``options`` +^^^^^^^^^^^ + +Options that apply to all action components. These are optional. + + * ``nodes``: list the nodes that the agent can act on, the order of this list defines the mapping between nodes and ``node_id`` integers. + * ``max_folders_per_node``, ``max_files_per_folder``, ``max_services_per_node``, ``max_nics_per_node``, ``max_acl_rules`` all are used to define the size of the action space. + +For more information see :py:mod:`primaite.game.agent.actions` + +``reward_function`` +------------------- + +Similar to action space, this is defined as a list of components from the :py:mod:`primaite.game.agent.rewards` module. + +``reward_components`` +^^^^^^^^^^^^^^^^^^^^^ + +A list of reward types from :py:mod:`primaite.game.agent.rewards.RewardFunction.rew_class_identifiers` + +e.g. + +.. code-block:: yaml + + reward_components: + - type: DUMMY + - type: DATABASE_FILE_INTEGRITY + + +``agent_settings`` +------------------ + +Settings passed to the agent during initialisation. Determines how the agent will behave during training. + +e.g. + +.. code-block:: yaml + + agent_settings: + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + +``start_step`` +^^^^^^^^^^^^^^ + +Optional. Default value is ``5``. + +The timestep where the agent begins performing actions. + +``frequency`` +^^^^^^^^^^^^^ + +Optional. Default value is ``5``. + +The number of timesteps the agent will wait before performing another action. + +``variance`` +^^^^^^^^^^^^ + +Optional. Default value is ``0``. + +The amount of timesteps that the frequency can randomly change. + +``flatten_obs`` +--------------- + +If ``True``, gymnasium flattening will be performed on the observation space before sending to the agent. Set this to ``True`` if your agent does not support nested observation spaces. diff --git a/docs/source/configuration/game.rst b/docs/source/configuration/game.rst new file mode 100644 index 00000000..828571a7 --- /dev/null +++ b/docs/source/configuration/game.rst @@ -0,0 +1,56 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + + +``game`` +======== +This section defines high-level settings that apply across the game, currently it's used to help shape the action and observation spaces by restricting which ports and internet protocols should be considered. Here, users can also set the maximum number of steps in an episode. + +``game`` hierarchy +------------------ + +.. code-block:: yaml + + game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +``max_episode_length`` +---------------------- + +Optional. Default value is ``256``. + +The maximum number of episodes a Reinforcement Learning agent(s) can be trained for. + +``ports`` +--------- + +A list of ports that the Reinforcement Learning agent(s) are able to see in the observation space. + +See :ref:`List of Ports ` for a list of ports. + +``protocols`` +------------- + +A list of protocols that the Reinforcement Learning agent(s) are able to see in the observation space. + +See :ref:`List of IPProtocols ` for a list of protocols. + +``thresholds`` +-------------- + +These are used to determine the thresholds of high, medium and low categories for counted observation occurrences. diff --git a/docs/source/configuration/io_settings.rst b/docs/source/configuration/io_settings.rst new file mode 100644 index 00000000..979dbfae --- /dev/null +++ b/docs/source/configuration/io_settings.rst @@ -0,0 +1,87 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + + +``io_settings`` +=============== +This section configures how PrimAITE saves data during simulation and training. + +``io_settings`` hierarchy +------------------------- + +.. code-block:: yaml + + io_settings: + save_final_model: True + save_checkpoints: False + checkpoint_interval: 10 + # save_logs: True + # save_transactions: False + save_agent_actions: True + save_step_metadata: False + save_pcap_logs: False + save_sys_logs: False + +``save_final_model`` +-------------------- + +Optional. Default value is ``True``. + +Only used if training with PrimaiteSession. +If ``True``, the policy will be saved after the final training iteration. + + +``save_checkpoints`` +-------------------- + +Optional. Default value is ``False``. + +Only used if training with PrimaiteSession. +If ``True``, the policy will be saved periodically during training. + + +``checkpoint_interval`` +----------------------- + +Optional. Default value is ``10``. + +Only used if training with PrimaiteSession and if ``save_checkpoints`` is ``True``. +Defines how often to save the policy during training. + + +``save_logs`` +------------- + +*currently unused*. + + +``save_agent_actions`` +---------------------- + +Optional. Default value is ``True``. + +If ``True``, this will create a JSON file each episode detailing every agent's action in each step of that episode, formatted according to the CAOS format. This includes scripted, RL, and red agents. + +``save_step_metadata`` +---------------------- + +Optional. Default value is ``False``. + +If ``True``, The RL agent(s) actions, environment states and other data will be saved at every single step. + + +``save_pcap_logs`` +------------------ + +Optional. Default value is ``False``. + +If ``True``, then the pcap files which contain all network traffic during the simulation will be saved. + + +``save_sys_logs`` +----------------- + +Optional. Default value is ``False``. + +If ``True``, then the log files which contain all node actions during the simulation will be saved. diff --git a/docs/source/configuration/simulation.rst b/docs/source/configuration/simulation.rst new file mode 100644 index 00000000..e2fa5476 --- /dev/null +++ b/docs/source/configuration/simulation.rst @@ -0,0 +1,93 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + + +``simulation`` +============== +In this section the network layout is defined. This part of the config follows a hierarchical structure. Almost every component defines a ``ref`` field which acts as a human-readable unique identifier, used by other parts of the config, such as agents. + +At the top level of the network are ``nodes`` and ``links``. + +e.g. + +.. code-block:: yaml + + simulation: + network: + nodes: + ... + links: + ... + +``nodes`` +--------- + +This is where the list of nodes are defined. Some items will differ according to the node type, however, there will be common items such as a node's reference (which is used by the agent), the node's ``type`` and ``hostname`` + +To see the configuration for these nodes, refer to the following: + +.. toctree:: + :maxdepth: 1 + :glob: + + simulation/nodes/* + +``links`` +--------- + +This is where the links between the nodes are formed. + +e.g. + +In order to recreate the network below, we will need to create 2 links: + +- a link from computer_1 to the switch +- a link from computer_2 to the switch + +.. image:: ../../_static/switched_p2p_network.png + :width: 500 + :align: center + +this results in: + +.. code-block:: yaml + + links: + - ref: computer_1___switch + endpoint_a_ref: computer_1 + endpoint_a_port: 1 # port 1 on computer_1 + endpoint_b_ref: switch + endpoint_b_port: 1 # port 1 on switch + - ref: computer_2___switch + endpoint_a_ref: computer_2 + endpoint_a_port: 1 # port 1 on computer_2 + endpoint_b_ref: switch + endpoint_b_port: 2 # port 2 on switch + +``ref`` +^^^^^^^ + +The human readable name for the link. Not used in code, however is useful for a human to understand what the link is for. + +``endpoint_a_ref`` +^^^^^^^^^^^^^^^^^^ + +The ``hostname`` of the node which must be connected. + +``endpoint_a_port`` +^^^^^^^^^^^^^^^^^^^ + +The port on ``endpoint_a_ref`` which is to be connected to ``endpoint_b_port``. +This accepts an integer value e.g. if port 1 is to be connected, the configuration should be ``endpoint_a_port: 1`` + +``endpoint_b_ref`` +^^^^^^^^^^^^^^^^^^ + +The ``hostname`` of the node which must be connected. + +``endpoint_b_port`` +^^^^^^^^^^^^^^^^^^^ + +The port on ``endpoint_b_ref`` which is to be connected to ``endpoint_a_port``. +This accepts an integer value e.g. if port 1 is to be connected, the configuration should be ``endpoint_b_port: 1`` diff --git a/docs/source/configuration/simulation/nodes/common/common.rst b/docs/source/configuration/simulation/nodes/common/common.rst new file mode 100644 index 00000000..d1c8f307 --- /dev/null +++ b/docs/source/configuration/simulation/nodes/common/common.rst @@ -0,0 +1,35 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _Node Attributes: + +Common Attributes +################# + +Node Attributes +=============== + +Attributes that are shared by all nodes. + +.. include:: common_node_attributes.rst + +.. _Network Node Attributes: + +Network Node Attributes +======================= + +Attributes that are shared by nodes that inherit from :py:mod:`primaite.simulator.network.hardware.nodes.network.network_node.NetworkNode` + +.. include:: common_host_node_attributes.rst + +.. _Host Node Attributes: + +Host Node Attributes +==================== + +Attributes that are shared by nodes that inherit from :py:mod:`primaite.simulator.network.hardware.nodes.host.host_node.HostNode` + +.. include:: common_host_node_attributes.rst + +.. |NODE| replace:: node diff --git a/docs/source/configuration/simulation/nodes/common/common_host_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_host_node_attributes.rst new file mode 100644 index 00000000..929d5714 --- /dev/null +++ b/docs/source/configuration/simulation/nodes/common/common_host_node_attributes.rst @@ -0,0 +1,26 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _common_host_node_attributes: + +``ip_address`` +-------------- + +The IP address of the |NODE| in the network. + +``subnet_mask`` +--------------- + +Optional. Default value is ``255.255.255.0``. + +The subnet mask for the |NODE| to use. + +``default_gateway`` +------------------- + +The IP address that the |NODE| will use as the default gateway. Typically, this is the IP address of the closest router that the |NODE| is connected to. + +.. include:: ../software/applications.rst + +.. include:: ../software/services.rst diff --git a/docs/source/configuration/simulation/nodes/common/common_network_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_network_node_attributes.rst new file mode 100644 index 00000000..1161059f --- /dev/null +++ b/docs/source/configuration/simulation/nodes/common/common_network_node_attributes.rst @@ -0,0 +1,51 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _common_network_node_attributes: + +``routes`` +---------- + +A list of routes which tells the |NODE| where to forward the packet to depending on the target IP address. + +e.g. + +.. code-block:: yaml + + nodes: + - ref: node + ... + routes: + - address: 192.168.0.10 + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.1 + metric: 0 + +``address`` +""""""""""" + +The target IP address for the route. If the packet destination IP address matches this, the |NODE| will route the packet according to the ``next_hop_ip_address``. + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``subnet_mask`` +""""""""""""""" + +Optional. Default value is ``255.255.255.0``. + +The subnet mask setting for the route. + +``next_hop_ip_address`` +""""""""""""""""""""""" + +The IP address of the next hop IP address that the packet will follow if the address matches the packet's destination IP address. + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``metric`` +"""""""""" + +Optional. Default value is ``0``. This value accepts floats. + +The cost or distance of a route. The higher the value, the more cost or distance is attributed to the route. diff --git a/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst new file mode 100644 index 00000000..34519adc --- /dev/null +++ b/docs/source/configuration/simulation/nodes/common/common_node_attributes.rst @@ -0,0 +1,55 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _common_node_attributes: + +``ref`` +------- + +Human readable name used as reference for the |NODE|. Not used in code. + +``hostname`` +------------ + +The hostname of the |NODE|. This will be used to reference the |NODE|. + +``operating_state`` +------------------- + +The initial operating state of the node. + +Optional. Default value is ``ON``. + +Options available are: + +- ``ON`` +- ``OFF`` +- ``BOOTING`` +- ``SHUTTING_DOWN`` + +Note that YAML may assume non quoted ``ON`` and ``OFF`` as ``True`` and ``False`` respectively. To prevent this, use ``"ON"`` or ``"OFF"`` + +See :py:mod:`primaite.simulator.network.hardware.node_operating_state.NodeOperatingState` + + +``dns_server`` +-------------- + +Optional. Default value is ``None``. + +The IP address of the node which holds an instance of the :ref:`DNSServer`. Some applications may use a domain name e.g. the :ref:`WebBrowser` + +``start_up_duration`` +--------------------- + +Optional. Default value is ``3``. + +The number of time steps required to occur in order for the node to cycle from ``OFF`` to ``BOOTING_UP`` and then finally ``ON``. + +``shut_down_duration`` +---------------------- + +Optional. Default value is ``3``. + +The number of time steps required to occur in order for the node to cycle from ``ON`` to ``SHUTTING_DOWN`` and then finally ``OFF``. diff --git a/docs/source/configuration/simulation/nodes/common/node_type_list.rst b/docs/source/configuration/simulation/nodes/common/node_type_list.rst new file mode 100644 index 00000000..ceee8207 --- /dev/null +++ b/docs/source/configuration/simulation/nodes/common/node_type_list.rst @@ -0,0 +1,18 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +``type`` +-------- + +The type of node to add. + +Available options are: + +- ``computer`` +- ``firewall`` +- ``router`` +- ``server`` +- ``switch`` + +To create a |NODE|, type must be |NODE_TYPE|. diff --git a/docs/source/configuration/simulation/nodes/computer.rst b/docs/source/configuration/simulation/nodes/computer.rst new file mode 100644 index 00000000..04a45766 --- /dev/null +++ b/docs/source/configuration/simulation/nodes/computer.rst @@ -0,0 +1,41 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _computer_configuration: + +``computer`` +============ + +A basic representation of a computer within the simulation. + +See :py:mod:`primaite.simulator.network.hardware.nodes.host.computer.Computer` + +example computer +---------------- + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: client_1 + hostname: client_1 + type: computer + ip_address: 192.168.0.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.0.1 + dns_server: 192.168.1.10 + applications: + ... + services: + ... + +.. include:: common/common_node_attributes.rst + +.. include:: common/node_type_list.rst + +.. include:: common/common_host_node_attributes.rst + +.. |NODE| replace:: computer +.. |NODE_TYPE| replace:: ``computer`` diff --git a/docs/source/configuration/simulation/nodes/firewall.rst b/docs/source/configuration/simulation/nodes/firewall.rst new file mode 100644 index 00000000..47db4001 --- /dev/null +++ b/docs/source/configuration/simulation/nodes/firewall.rst @@ -0,0 +1,300 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _firewall_configuration: + +``firewall`` +============ + +A basic representation of a network firewall within the simulation. + +The firewall is similar to how :ref:`Router ` works, with the difference being how firewall has specific ACL rules for inbound and outbound traffic as well as firewall being limited to 3 ports. + +See :py:mod:`primaite.simulator.network.hardware.nodes.network.firewall.Firewall` + +example firewall +---------------- + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: firewall + hostname: firewall + type: firewall + start_up_duration: 0 + shut_down_duration: 0 + ports: + external_port: # port 1 + ip_address: 192.168.20.1 + subnet_mask: 255.255.255.0 + internal_port: # port 2 + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + dmz_port: # port 3 + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + internal_inbound_acl: + ... + internal_outbound_acl: + ... + dmz_inbound_acl: + ... + dmz_outbound_acl: + ... + external_inbound_acl: + ... + external_outbound_acl: + ... + routes: + ... + +.. include:: common/common_node_attributes.rst + +.. include:: common/node_type_list.rst + +``ports`` +--------- + +The firewall node only has 3 ports. These specifically are: + +- ``external_port`` (port 1) +- ``internal_port`` (port 2) +- ``dmz_port`` (port 3) (can be optional) + +The ports should be defined with an ip address and subnet mask e.g. + +.. code-block:: yaml + + nodes: + - ref: firewall + ... + ports: + external_port: # port 1 + ip_address: 192.168.20.1 + subnet_mask: 255.255.255.0 + internal_port: # port 2 + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + dmz_port: # port 3 + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + +``ip_address`` +"""""""""""""" + +The IP address for the given port. This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``subnet_mask`` +""""""""""""""" + +Optional. Default value is ``255.255.255.0``. + +The subnet mask setting for the port. + +``acl`` +------- + +There are 6 ACLs that can be defined for a firewall + +- ``internal_inbound_acl`` for traffic going towards the internal network +- ``internal_outbound_acl`` for traffic coming from the internal network +- ``dmz_inbound_acl`` for traffic going towards the dmz network +- ``dmz_outbound_acl`` for traffic coming from the dmz network +- ``external_inbound_acl`` for traffic coming from the external network +- ``external_outbound_acl`` for traffic going towards the external network + +.. image:: ../../../../_static/firewall_acl.png + :width: 500 + :align: center + +By default, ``external_inbound_acl`` and ``external_outbound_acl`` will permit any traffic through. + +``internal_inbound_acl``, ``internal_outbound_acl``, ``dmz_inbound_acl`` and ``dmz_outbound_acl`` will deny any traffic by default, so must be configured to allow defined ``src_port`` and ``dst_port`` or ``protocol``. + +See :py:mod:`primaite.simulator.network.hardware.nodes.network.router.AccessControlList` + +See :ref:`List of Ports ` for a list of ports. + +``internal_inbound_acl`` +"""""""""""""""""""""""" + +ACL rules for packets that have a destination IP address in what is considered the internal network. + +example: + +.. code-block:: yaml + + nodes: + - ref: firewall + ... + acl: + internal_inbound_acl: + 21: # position 21 on ACL list + action: PERMIT # allow packets that + src_port: POSTGRES_SERVER # are emitted from the POSTGRES_SERVER port + dst_port: POSTGRES_SERVER # are going towards an POSTGRES_SERVER port + 22: # position 22 on ACL list + action: PERMIT # allow packets that + src_port: ARP # are emitted from the ARP port + dst_port: ARP # are going towards an ARP port + 23: # position 23 on ACL list + action: PERMIT # allow packets that + protocol: ICMP # are ICMP + +``internal_outbound_acl`` +""""""""""""""""""""""""" + +ACL rules for packets that have a source IP address in what is considered the internal network and is going towards the DMZ network or the external network. + +example: + +.. code-block:: yaml + + nodes: + - ref: firewall + ... + acl: + internal_outbound_acl: + 21: # position 21 on ACL list + action: PERMIT # allow packets that + src_port: POSTGRES_SERVER # are emitted from the POSTGRES_SERVER port + dst_port: POSTGRES_SERVER # are going towards an POSTGRES_SERVER port + 22: # position 22 on ACL list + action: PERMIT # allow packets that + src_port: ARP # are emitted from the ARP port + dst_port: ARP # are going towards an ARP port + 23: # position 23 on ACL list + action: PERMIT # allow packets that + protocol: ICMP # are ICMP + + +``dmz_inbound_acl`` +""""""""""""""""""" + +ACL rules for packets that have a destination IP address in what is considered the DMZ network. + +example: + +.. code-block:: yaml + + nodes: + - ref: firewall + ... + acl: + dmz_inbound_acl: + 19: # position 19 on ACL list + action: PERMIT # allow packets that + src_port: POSTGRES_SERVER # are emitted from the POSTGRES_SERVER port + dst_port: POSTGRES_SERVER # are going towards an POSTGRES_SERVER port + 20: # position 20 on ACL list + action: PERMIT # allow packets that + src_port: HTTP # are emitted from the HTTP port + dst_port: HTTP # are going towards an HTTP port + 21: # position 21 on ACL list + action: PERMIT # allow packets that + src_port: HTTPS # are emitted from the HTTPS port + dst_port: HTTPS # are going towards an HTTPS port + 22: # position 22 on ACL list + action: PERMIT # allow packets that + src_port: ARP # are emitted from the ARP port + dst_port: ARP # are going towards an ARP port + 23: # position 23 on ACL list + action: PERMIT # allow packets that + protocol: ICMP # are ICMP + +``dmz_outbound_acl`` +"""""""""""""""""""" + +ACL rules for packets that have a source IP address in what is considered the DMZ network and is going towards the internal network or the external network. + +example: + +.. code-block:: yaml + + nodes: + - ref: firewall + ... + acl: + dmz_outbound_acl: + 19: # position 19 on ACL list + action: PERMIT # allow packets that + src_port: POSTGRES_SERVER # are emitted from the POSTGRES_SERVER port + dst_port: POSTGRES_SERVER # are going towards an POSTGRES_SERVER port + 20: # position 20 on ACL list + action: PERMIT # allow packets that + src_port: HTTP # are emitted from the HTTP port + dst_port: HTTP # are going towards an HTTP port + 21: # position 21 on ACL list + action: PERMIT # allow packets that + src_port: HTTPS # are emitted from the HTTPS port + dst_port: HTTPS # are going towards an HTTPS port + 22: # position 22 on ACL list + action: PERMIT # allow packets that + src_port: ARP # are emitted from the ARP port + dst_port: ARP # are going towards an ARP port + 23: # position 23 on ACL list + action: PERMIT # allow packets that + protocol: ICMP # are ICMP + + + +``external_inbound_acl`` +"""""""""""""""""""""""" + +Optional. By default, this will allow any traffic through. + +ACL rules for packets that have a destination IP address in what is considered the external network. + +example: + +.. code-block:: yaml + + nodes: + - ref: firewall + ... + acl: + external_inbound_acl: + 21: # position 19 on ACL list + action: DENY # deny packets that + src_port: POSTGRES_SERVER # are emitted from the POSTGRES_SERVER port + dst_port: POSTGRES_SERVER # are going towards an POSTGRES_SERVER port + 22: # position 22 on ACL list + action: PERMIT # allow packets that + src_port: ARP # are emitted from the ARP port + dst_port: ARP # are going towards an ARP port + 23: # position 23 on ACL list + action: PERMIT # allow packets that + protocol: ICMP # are ICMP + +``external_outbound_acl`` +""""""""""""""""""""""""" + +Optional. By default, this will allow any traffic through. + +ACL rules for packets that have a source IP address in what is considered the external network and is going towards the DMZ network or the internal network. + +example: + +.. code-block:: yaml + + nodes: + - ref: firewall + ... + acl: + external_outbound_acl: + 22: # position 22 on ACL list + action: PERMIT # allow packets that + src_port: ARP # are emitted from the ARP port + dst_port: ARP # are going towards an ARP port + 23: # position 23 on ACL list + action: PERMIT # allow packets that + protocol: ICMP # are ICMP + +.. include:: common/common_network_node_attributes.rst + +.. |NODE| replace:: firewall +.. |NODE_TYPE| replace:: ``firewall`` diff --git a/docs/source/configuration/simulation/nodes/router.rst b/docs/source/configuration/simulation/nodes/router.rst new file mode 100644 index 00000000..b9ba1ad5 --- /dev/null +++ b/docs/source/configuration/simulation/nodes/router.rst @@ -0,0 +1,127 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _router_configuration: + +``router`` +========== + +A basic representation of a network router within the simulation. + +See :py:mod:`primaite.simulator.network.hardware.nodes.network.router.Router` + +example router +-------------- + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: router_1 + hostname: router_1 + type: router + num_ports: 5 + ports: + ... + acl: + ... + +.. include:: common/common_node_attributes.rst + +.. include:: common/node_type_list.rst + +``num_ports`` +------------- + +Optional. Default value is ``5``. + +The number of ports the router will have. + +``ports`` +--------- + +Sets up the router's ports with an IP address and a subnet mask. + +Example of setting ports for a router with 2 ports: + +.. code-block:: yaml + + nodes: + - ref: router_1 + ... + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + +``ip_address`` +"""""""""""""" + +The IP address for the given port. This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``subnet_mask`` +""""""""""""""" + +Optional. Default value is ``255.255.255.0``. + +The subnet mask setting for the port. + +``acl`` +------- + +Sets up the ACL rules for the router. + +e.g. + +.. code-block:: yaml + + nodes: + - ref: router_1 + ... + acl: + 1: + action: PERMIT + src_port: ARP + dst_port: ARP + 2: + action: PERMIT + protocol: ICMP + +See :py:mod:`primaite.simulator.network.hardware.nodes.network.router.AccessControlList` + +See :ref:`List of Ports ` for a list of ports. + +``action`` +"""""""""" + +Available options are + +- ``PERMIT`` : Allows the specified ``protocol`` or ``src_port`` and ``dst_port`` pairs +- ``DENY`` : Blocks the specified ``protocol`` or ``src_port`` and ``dst_port`` pairs + +``src_port`` +"""""""""""" + +Is used alongside ``dst_port``. Specifies the port where a packet originates. Used by the ACL Rule to determine if a packet with a specific source port is allowed to pass through the network node. + +``dst_port`` +"""""""""""" + +Is used alongside ``src_port``. Specifies the port where a packet is destined to arrive. Used by the ACL Rule to determine if a packet with a specific destination port is allowed to pass through the network node. + +``protocol`` +"""""""""""" + +Specifies which protocols are allowed by the ACL Rule to pass through the network node. + +See :ref:`List of IPProtocols ` for a list of protocols. + +.. include:: common/common_network_node_attributes.rst + +.. |NODE| replace:: router +.. |NODE_TYPE| replace:: ``router`` diff --git a/docs/source/configuration/simulation/nodes/server.rst b/docs/source/configuration/simulation/nodes/server.rst new file mode 100644 index 00000000..dbc32235 --- /dev/null +++ b/docs/source/configuration/simulation/nodes/server.rst @@ -0,0 +1,41 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _server_configuration: + +``server`` +========== + +A basic representation of a server within the simulation. + +See :py:mod:`primaite.simulator.network.hardware.nodes.host.server.Server` + +example server +-------------- + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: server_1 + hostname: server_1 + type: server + ip_address: 192.168.10.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + ... + services: + ... + +.. include:: common/common_node_attributes.rst + +.. include:: common/node_type_list.rst + +.. include:: common/common_host_node_attributes.rst + +.. |NODE| replace:: server +.. |NODE_TYPE| replace:: ``server`` diff --git a/docs/source/configuration/simulation/nodes/switch.rst b/docs/source/configuration/simulation/nodes/switch.rst new file mode 100644 index 00000000..263bedbb --- /dev/null +++ b/docs/source/configuration/simulation/nodes/switch.rst @@ -0,0 +1,39 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _switch_configuration: + +``switch`` +========== + +A basic representation of a network switch within the simulation. + +See :py:mod:`primaite.simulator.network.hardware.nodes.network.switch.Switch` + +example switch +-------------- + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: switch_1 + hostname: switch_1 + type: switch + num_ports: 8 + +.. include:: common/common_node_attributes.rst + +.. include:: common/node_type_list.rst + +``num_ports`` +------------- + +Optional. Default value is ``8``. + +The number of ports the switch will have. + +.. |NODE| replace:: switch +.. |NODE_TYPE| replace:: ``switch`` diff --git a/docs/source/configuration/simulation/software/applications.rst b/docs/source/configuration/simulation/software/applications.rst new file mode 100644 index 00000000..90ae3ec1 --- /dev/null +++ b/docs/source/configuration/simulation/software/applications.rst @@ -0,0 +1,25 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +``applications`` +---------------- + +List of available applications that can be installed on a |NODE| can be found in :ref:`List of Applications ` + +application in configuration +"""""""""""""""""""""""""""" + +Applications takes a list of applications as shown in the example below. + +.. code-block:: yaml + + - ref: client_1 + hostname: client_1 + type: computer + ... + applications: + - ref: example_application + type: example_application_type + options: + # this section is different for each application diff --git a/docs/source/configuration/simulation/software/services.rst b/docs/source/configuration/simulation/software/services.rst new file mode 100644 index 00000000..88957001 --- /dev/null +++ b/docs/source/configuration/simulation/software/services.rst @@ -0,0 +1,25 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +``services`` +------------ + +List of available services that can be installed on a |NODE| can be found in :ref:`List of Services ` + +services in configuration +""""""""""""""""""""""""" + +Services takes a list of services as shown in the example below. + +.. code-block:: yaml + + - ref: client_1 + hostname: client_1 + type: computer + ... + applications: + - ref: example_service + type: example_service_type + options: + # this section is different for each service diff --git a/docs/source/configuration/training_config.rst b/docs/source/configuration/training_config.rst new file mode 100644 index 00000000..3e63f69b --- /dev/null +++ b/docs/source/configuration/training_config.rst @@ -0,0 +1,75 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +``training_config`` +=================== +Configuration items relevant to how the Reinforcement Learning agent(s) will be trained. + +``training_config`` hierarchy +----------------------------- + +.. code-block:: yaml + + training_config: + rl_framework: SB3 # or RLLIB_single_agent or RLLIB_multi_agent + rl_algorithm: PPO # or A2C + n_learn_episodes: 5 + max_steps_per_episode: 200 + n_eval_episodes: 1 + deterministic_eval: True + seed: 123 + + +``rl_framework`` +---------------- +The RL (Reinforcement Learning) Framework to use in the training session + +Options available are: + +- ``SB3`` (Stable Baselines 3) +- ``RLLIB_single_agent`` (Single Agent Ray RLLib) +- ``RLLIB_multi_agent`` (Multi Agent Ray RLLib) + +``rl_algorithm`` +---------------- +The Reinforcement Learning Algorithm to use in the training session + +Options available are: + +- ``PPO`` (Proximal Policy Optimisation) +- ``A2C`` (Advantage Actor Critic) + +``n_learn_episodes`` +-------------------- +The number of episodes to train the agent(s). +This should be an integer value above ``0`` + +``max_steps_per_episode`` +------------------------- +The number of steps each episode will last for. +This should be an integer value above ``0``. + + +``n_eval_episodes`` +------------------- +Optional. Default value is ``0``. + +The number of evaluation episodes to run the trained agent for. +This should be an integer value above ``0``. + +``deterministic_eval`` +---------------------- +Optional. By default this value is ``False``. + +If this is set to ``True``, the agents will act deterministically instead of stochastically. + + + +``seed`` +-------- +Optional. + +The seed is used (alongside ``deterministic_eval``) to reproduce a previous instance of training and evaluation of an RL agent. +The seed should be an integer value. +Useful for debugging. diff --git a/docs/source/customising_scenarios.rst b/docs/source/customising_scenarios.rst new file mode 100644 index 00000000..709f032a --- /dev/null +++ b/docs/source/customising_scenarios.rst @@ -0,0 +1,4 @@ +Customising Agents +****************** + +For an example of how to customise red agent behaviour in the Data Manipulation scenario, please refer to the notebook ``Data-Manipulation-Customising-Red-Agent.ipynb``. diff --git a/docs/source/dependencies.rst b/docs/source/dependencies.rst index 942ccfd8..ddea27fa 100644 --- a/docs/source/dependencies.rst +++ b/docs/source/dependencies.rst @@ -5,6 +5,8 @@ .. role:: raw-html(raw) :format: html +.. _Dependencies: + Dependencies ============ diff --git a/docs/source/environment.rst b/docs/source/environment.rst new file mode 100644 index 00000000..2b76572d --- /dev/null +++ b/docs/source/environment.rst @@ -0,0 +1,10 @@ +RL Environments +*************** + +RL environments are the objects that directly interface with RL libraries such as Stable-Baselines3 and Ray RLLib. The PrimAITE simulation is exposed via three different environment APIs: + +* Gymnasium API - this is the standard interface that works with many RL libraries like SB3, Ray, Tianshou, etc. ``PrimaiteGymEnv`` adheres to the `Official Gymnasium documentation `_. +* Ray Single agent API - For training a single Ray RLLib agent +* Ray MARL API - For training multi-agent systems with Ray RLLib. ``PrimaiteRayMARLEnv`` adheres to the `Official Ray documentation `_. + +There are Jupyter notebooks which demonstrate integration with each of these three environments. They are located in ``~/primaite//notebooks/example_notebooks``. diff --git a/docs/source/example_notebooks.rst b/docs/source/example_notebooks.rst new file mode 100644 index 00000000..99d47822 --- /dev/null +++ b/docs/source/example_notebooks.rst @@ -0,0 +1,77 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +Example Jupyter Notebooks +========================= + +There are a few example notebooks included which help with the understanding of PrimAITE's capabilities. + +The Jupyter Notebooks can be run via the 2 examples below. These assume that the instructions to install PrimAITE from the :ref:`Getting Started ` page is completed as a prerequisite. + +Running Jupyter Notebooks +------------------------- + +1. Navigate to the PrimAITE directory + +.. code-block:: bash + :caption: Unix + + cd ~/primaite/{VERSION} + +.. code-block:: powershell + :caption: Windows (Powershell) + + cd ~\primaite\{VERSION} + +2. Run jupyter notebook (the python environment to which you installed PrimAITE must be active) + +.. code-block:: bash + :caption: Unix + + jupyter notebook + +.. code-block:: powershell + :caption: Windows (Powershell) + + jupyter notebook + +3. Opening the jupyter webpage (optional) + +The default web browser may automatically open the webpage. However, if that is not the case, click the link shown in your command prompt output. It should look like this: ``http://localhost:8888/?token=0123456798abc0123456789abc`` + + +4. Navigate to the list of notebooks + +The example notebooks are located in ``notebooks/example_notebooks/``. The file system shown in the jupyter webpage is relative to the location in which the ``jupyter notebook`` command was used. + + +Running Jupyter Notebooks via VSCode +------------------------------------ + +It is also possible to view the Jupyter notebooks within VSCode. + +The best place to start is by opening a notebook file (.ipynb) in VSCode. If using VSCode to view a notebook for the first time, follow the steps below. + +Installing extensions +""""""""""""""""""""" + +VSCode may need some extensions to be installed if not already done. +To do this, press the "Select Kernel" button on the top right. + +This should open a dialog which has the option to install python and jupyter extensions. + +.. image:: ../../_static/notebooks/install_extensions.png + :width: 700 + :align: center + :alt: :: The top dialog option that appears will automatically install the extensions + +The following extensions should now be installed + +.. image:: ../../_static/notebooks/extensions.png + :width: 300 + :align: center + +VSCode will then ask for a Python environment version to use. PrimAITE is compatible with Python versions 3.8 - 3.10 + +You should now be able to interact with the notebook. diff --git a/docs/source/game_layer.rst b/docs/source/game_layer.rst index cdae17dd..af3eadc6 100644 --- a/docs/source/game_layer.rst +++ b/docs/source/game_layer.rst @@ -6,44 +6,91 @@ The Primaite codebase consists of two main modules: * ``simulator``: The simulation logic including the network topology, the network state, and behaviour of various hardware and software classes. * ``game``: The agent-training infrastructure which helps reinforcement learning agents interface with the simulation. This includes the observation, action, and rewards, for RL agents, but also scripted deterministic agents. The game layer orchestrates all the interactions between modules. - The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API. - -.. - TODO: write up these APIs and link them here. - - -Game layer ----------- +The simulator and game layer communicate using the PrimAITE State API and the PrimAITE Request API. The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components: PrimAITE Session -^^^^^^^^^^^^^^^ +================ + +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. ``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types. Agents -^^^^^^ +====== All agents inherit from the :py:class:`primaite.game.agent.interface.AbstractAgent` class, which mandates that they have an ObservationManager, ActionManager, and RewardManager. The agent behaviour depends on the type of agent, but there are two main types: * RL agents action during each step is decided by an appropriate RL algorithm. The agent within PrimAITE just acts to format and forward actions decided by an RL policy. -* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed will be settable. +* Deterministic agents perform all of their decision making within the PrimAITE game layer. They typically have a scripted policy which always performs the same action or a rule-based policy which performs actions based on the current state of the simulation. They can have a stochastic element, and their seed is settable. -.. - TODO: add seed to stochastic scripted agents Observations -^^^^^^^^^^^^^^^^^^ +============ An agent's observations are managed by the ``ObservationManager`` class. It generates observations based on the current simulation state dictionary. It also provides the observation space during initial setup. The data is formatted so it's compatible with ``Gymnasium.spaces``. Observation spaces are composed of one or more components which are defined by the ``AbstractObservation`` base class. Actions -^^^^^^^ +======= An agent's actions are managed by the ``ActionManager``. It converts actions selected by agents (which are typically integers chosen from a ``gymnasium.spaces.Discrete`` space) into simulation-friendly requests. It also provides the action space during initial setup. Action spaces are composed of one or more components which are defined by the ``AbstractAction`` base class. Rewards -^^^^^^^ +======= -An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. The reward components are defined by the AbstractReward base class. +An agent's reward function is managed by the ``RewardManager``. It calculates rewards based on the simulation state (in a way similar to observations). Rewards can be defined as a weighted sum of small reward components. For example, an agents reward can be based on the uptime of a database service plus the loss rate of packets between clients and a web server. + +Reward Components +----------------- + +Currently implemented are reward components tailored to the data manipulation scenario. View the full API and description of how they work here: :py:module:`primaite.game.agent.reward`. + +Reward Sharing +-------------- + +An agent's reward can be based on rewards of other agents. This is particularly useful for modelling a situation where the blue agent's job is to protect the ability of green agents to perform their pattern-of-life. This can be configured in the YAML file this way: + +```yaml +green_agent_1: # this agent sometimes tries to access the webpage, and sometimes the database + # actions, observations, and agent settings go here + reward_function: + reward_components: + + # When the webpage loads, the reward goes up by 0.25 when it fails to load, it goes down to -0.25 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + + # When the database is reachable, the reward goes up by 0.05, when it is unreachable it goes down to -0.05 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + +blue_agent: + # actions, observations, and agent settings go here + reward_function: + reward_components: + + # When the database file is in a good state, blue's reward is 0.4, when it's in a corrupted state the reward is -0.4 + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + # The green's reward is added onto the blue's reward. + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + +``` + +When defining agent reward sharing, users must be careful to avoid circular references, as that would lead to an infinite calculation loop. PrimAITE will prevent circular dependencies and provide a helpful error message if they are detected in the yaml. diff --git a/docs/source/getting_started.rst b/docs/source/getting_started.rst index a800ee56..bb6e0019 100644 --- a/docs/source/getting_started.rst +++ b/docs/source/getting_started.rst @@ -11,7 +11,7 @@ Getting Started Pre-Requisites -In order to get **PrimAITE** installed, you will need to have a python version between 3.8 and 3.11 installed. If you don't already have it, this is how to install it: +In order to get **PrimAITE** installed, you will need Python, venv, and pip. If you don't already have them, this is how to install it: .. code-block:: bash @@ -30,6 +30,8 @@ In order to get **PrimAITE** installed, you will need to have a python version b **PrimAITE** is designed to be OS-agnostic, and thus should work on most variations/distros of Linux, Windows, and MacOS. +Installing PrimAITE has been tested with all supported python versions, venv 20.24.1, and pip 23. + Install PrimAITE **************** @@ -38,12 +40,12 @@ Install PrimAITE .. code-block:: bash :caption: Unix - mkdir ~/primaite/3.0.0 + mkdir -p ~/primaite/{VERSION} .. code-block:: powershell :caption: Windows (Powershell) - mkdir ~\primaite\3.0.0 + mkdir ~\primaite\{VERSION} 2. Navigate to the primaite directory and create a new python virtual environment (venv) @@ -51,13 +53,13 @@ Install PrimAITE .. code-block:: bash :caption: Unix - cd ~/primaite/3.0.0 + cd ~/primaite/{VERSION} python3 -m venv .venv .. code-block:: powershell :caption: Windows (Powershell) - cd ~\primaite\3.0.0 + cd ~\primaite\{VERSION} python3 -m venv .venv attrib +h .venv /s /d # Hides the .venv directory diff --git a/docs/source/primaite_session.rst b/docs/source/primaite_session.rst index 706397b6..d0caeaad 100644 --- a/docs/source/primaite_session.rst +++ b/docs/source/primaite_session.rst @@ -4,6 +4,11 @@ .. _run a primaite session: +.. admonition:: Deprecated + :class: deprecated + + PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality. + Run a PrimAITE Session ====================== @@ -30,7 +35,7 @@ Outputs ------- Running a session creates a session output directory in your user data folder. The filepath looks like this: -``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node, +``~/primaite/{VERSION}/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node, the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that contains the action, reward, and simulation state. These can be found in -``~/primaite/3.0.0/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_/step_metadata/step_.json`` +``~/primaite/{VERSION}/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_/step_metadata/step_.json`` diff --git a/docs/source/request_system.rst b/docs/source/request_system.rst index 392bc792..fb9d3978 100644 --- a/docs/source/request_system.rst +++ b/docs/source/request_system.rst @@ -3,7 +3,7 @@ © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK Request System -============== +************** ``SimComponent`` objects in the simulation are decoupled from the agent training logic. However, they still need a managed means of accepting requests to perform actions. For this, they use ``RequestManager`` and ``RequestType``. @@ -12,31 +12,42 @@ Just like other aspects of SimComponent, the request types are not managed centr - API When requesting an action within the simulation, these two arguments must be provided: - 1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as `['network', 'node', '', 'service', '', 'restart']`. + 1. ``request`` - selects which action you want to take on this ``SimComponent``. This is formatted as a list of strings such as ``['network', 'node', '', 'service', '', 'restart']``. 2. ``context`` - optional extra information that can be used to decide how to process the request. This is formatted as a dictionary. For example, if the request requires authentication, the context can include information about the user that initiated the request to decide if their permissions are sufficient. + When a request is resolved, it returns a success status, and optional additional data about the request. + + ``status`` can be one of: + + * ``success``: the request was executed + * ``failure``: the request could not be executed + * ``unreachable``: the target for the request was not found + * ``pending``: the request was initiated, but has not finished during this step + + ``data`` can be a dictionary with any arbitrary JSON-like data to describe the outcome of the request. + - ``request`` detail The request is a list of strings which help specify who should handle the request. The strings in the request list help RequestManagers traverse the 'ownership tree' of SimComponent. The example given above would be handled in the following way: - 1. ``Simulation`` receives `['network', 'node', '', 'service', '', 'restart']`. + 1. ``Simulation`` receives ``['network', 'node', 'computer_1', 'service', 'DNSService', 'restart']``. The first element of the request is ``network``, therefore it passes the request down to its network. - 2. ``Network`` receives `['node', '', 'service', '', 'restart']`. + 2. ``Network`` receives ``['node', 'computer_1', 'service', 'DNSService', 'restart']``. The first element of the request is ``node``, therefore the network looks at the node name and passes the request down to the node with that name. - 3. ``Node`` receives `['service', '', 'restart']`. + 3. ``computer_1`` receives ``['service', 'DNSService', 'restart']``. The first element of the request is ``service``, therefore the node looks at the service name and passes the rest of the request to the service with that name. - 4. ``Service`` receives ``['restart']``. + 4. ``DNSService`` receives ``['restart']``. Since ``restart`` is a defined request type in the service's own RequestManager, the service performs a restart. - ``context`` detail The context is not used by any of the currently implemented components or requests. Technical Detail ----------------- +================ This system was achieved by implementing two classes, :py:class:`primaite.simulator.core.RequestType`, and :py:class:`primaite.simulator.core.RequestManager`. ``RequestType`` ------- +--------------- The ``RequestType`` object stores a reference to a method that executes the request, for example a node could have a request type that stores a reference to ``self.turn_on()``. Technically, this can be any callable that accepts `request, context` as it's parameters. In practice, this is often defined using ``lambda`` functions within a component's ``self._init_request_manager()`` method. Optionally, the ``RequestType`` object can also hold a validator that will permit/deny the request depending on context. @@ -60,7 +71,7 @@ A simple example without chaining can be seen in the :py:class:`primaite.simulat *ellipses (``...``) used to omit code impertinent to this explanation* Chaining RequestManagers ------------------------ +------------------------ A request function needs to be a callable that accepts ``request, context`` as parameters. Since the request manager resolves requests by invoking it with ``request, context`` as parameter, it is possible to use a ``RequestManager`` as a ``RequestType``. @@ -93,3 +104,19 @@ An example of how this works is in the :py:class:`primaite.simulator.network.har self._service_request_manager.add_request(service.name, RequestType(func=service._request_manager)) This process is repeated until the request word corresponds to a callable function rather than another ``RequestManager`` . + +Request Validation +------------------ + +There are times when a request should be rejected. For instance, if an agent attempts to run an application on a node that is currently off. For this purpose, requests are filtered by an object called a validator. :py:class:`primaite.simulator.core.RequestPermissionValidator` is a basic class whose ``__call__()`` method returns ``True`` if the request should be permitted or ``False`` if it cannot be permitted. For example, the Node class has a validator called :py:class:`primaite.simulator.network.hardware.base.Node._NodeIsOnValidator<_NodeIsOnValidator>` which allows requests only when the operating status of the node is ``ON``. + +Requests that are specified without a validator automatically get assigned an ``AllowAllValidator`` which allows requests no matter what. + +Request Response +---------------- + +The :py:class:`primaite.interface.request.RequestResponse` is a data transfer object that carries response data between the simulator and the game layer. The ``status`` field reports on the success or failure, and the ``data`` field is for any additional data. The most common way that this class is initiated is by its ``from_bool`` method. This way, given a True or False, a successful or failed request response is generated, respectively (with an empty data field). + +For instance, the ``execute`` action on a :py:class:`primaite.simulator.system.applications.web_browser.WebBrowser` calls the ``get_webpage()`` method of the ``WebBrowser``. ``get_webpage()`` returns a True if the webpage was successfully retrieved, and False if unsuccessful for any reason, such as being blocked by an ACL, or if the database server is unresponsive. The boolean returned from ``get_webpage()`` is used to create the request response. + +Just as the requests themselves were passed from owner to component, the request response is bubbled back up from component to owner until it arrives at the game layer. diff --git a/docs/source/simulation.rst b/docs/source/simulation.rst index c703b299..c4bf1bf0 100644 --- a/docs/source/simulation.rst +++ b/docs/source/simulation.rst @@ -22,9 +22,9 @@ Contents simulation_components/network/nodes/host_node simulation_components/network/nodes/network_node simulation_components/network/nodes/router + simulation_components/network/nodes/switch simulation_components/network/nodes/wireless_router simulation_components/network/nodes/firewall - simulation_components/network/switch simulation_components/network/network simulation_components/system/internal_frame_processing simulation_components/system/sys_log diff --git a/docs/source/simulation_components/network/base_hardware.rst b/docs/source/simulation_components/network/base_hardware.rst index c7545810..1b83f3f4 100644 --- a/docs/source/simulation_components/network/base_hardware.rst +++ b/docs/source/simulation_components/network/base_hardware.rst @@ -12,34 +12,81 @@ complex, specialized hardware components inherit from and build upon. The key elements defined in ``base.py`` are: -NetworkInterface -================ +``NetworkInterface`` +==================== - Abstract base class for network interfaces like NICs. Defines common attributes like MAC address, speed, MTU. - Requires subclasses to implement ``enable()``, ``disable()``, ``send_frame()`` and ``receive_frame()``. - Provides basic state description and request handling capabilities. -Node -==== +``Node`` +======== The Node class stands as a central component in ``base.py``, acting as the superclass for all network nodes within a PrimAITE simulation. - - Node Attributes --------------- +See :ref:`Node Attributes` -- **hostname**: The network hostname of the node. -- **operating_state**: Indicates the current hardware state of the node. -- **network_interfaces**: Maps interface names to NetworkInterface objects on the node. -- **network_interface**: Maps port IDs to ``NetworkInterface`` objects on the node. -- **dns_server**: Specifies DNS servers for domain name resolution. -- **start_up_duration**: The time it takes for the node to become fully operational after being powered on. -- **shut_down_duration**: The time required for the node to properly shut down. -- **sys_log**: A system log for recording events related to the node. -- **session_manager**: Manages user sessions within the node. -- **software_manager**: Controls the installation and management of software and services on the node. +.. _Node Start up and Shut down: + +Node Start up and Shut down +--------------------------- +Nodes are powered on and off over multiple timesteps. By default, the node ``start_up_duration`` and ``shut_down_duration`` is 3 timesteps. + +Example code where a node is turned on: + +.. code-block:: python + + from primaite.simulator.network.hardware.base import Node + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState + + node = Node(hostname="pc_a") + + assert node.operating_state is NodeOperatingState.OFF # By default, node is instantiated in an OFF state + + node.power_on() # power on the node + + assert node.operating_state is NodeOperatingState.BOOTING # node is booting up + + for i in range(node.start_up_duration + 1): + # apply timestep until the node start up duration + node.apply_timestep(timestep=i) + + assert node.operating_state is NodeOperatingState.ON # node is in ON state + + +If the node needs to be instantiated in an on state: + + +.. code-block:: python + + from primaite.simulator.network.hardware.base import Node + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState + + node = Node(hostname="pc_a", operating_state=NodeOperatingState.ON) + + assert node.operating_state is NodeOperatingState.ON # node is in ON state + +Setting ``start_up_duration`` and/or ``shut_down_duration`` to ``0`` will allow for the ``power_on`` and ``power_off`` methods to be completed instantly without applying timesteps: + +.. code-block:: python + + from primaite.simulator.network.hardware.base import Node + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState + + node = Node(hostname="pc_a", start_up_duration=0, shut_down_duration=0) + + assert node.operating_state is NodeOperatingState.OFF # node is in OFF state + + node.power_on() + + assert node.operating_state is NodeOperatingState.ON # node is in ON state + + node.power_off() + + assert node.operating_state is NodeOperatingState.OFF # node is in OFF state Node Behaviours/Functions ------------------------- diff --git a/docs/source/simulation_components/network/network.rst b/docs/source/simulation_components/network/network.rst index 533a15f2..36e8ee48 100644 --- a/docs/source/simulation_components/network/network.rst +++ b/docs/source/simulation_components/network/network.rst @@ -30,11 +30,11 @@ we'll use the following Network that has a client, server, two switches, and a r .. code-block:: python from primaite.simulator.network.container import Network - from primaite.simulator.network.hardware.base import NIC - from primaite.simulator.network.hardware.nodes.computer import Computer - from primaite.simulator.network.hardware.nodes.router import Router, ACLAction - from primaite.simulator.network.hardware.nodes.server import Server - from primaite.simulator.network.hardware.nodes.switch import Switch + from primaite.simulator.network.hardware.base import NetworkInterface + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.network.hardware.nodes.network.router import Router, ACLAction + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port diff --git a/docs/source/simulation_components/network/network_interfaces.rst b/docs/source/simulation_components/network/network_interfaces.rst index 9e1ad80a..ffba58e4 100644 --- a/docs/source/simulation_components/network/network_interfaces.rst +++ b/docs/source/simulation_components/network/network_interfaces.rst @@ -13,6 +13,8 @@ facilitates modular development, enhances maintainability, and supports scalabil allowing for focused enhancements within each layer. .. image:: primaite_network_interface_model.png + :width: 500 + :align: center Layer Descriptions ================== @@ -65,9 +67,14 @@ Network Interface Classes **NetworkInterface (Base Layer)** -Abstract base class defining core interface properties like MAC address, speed, MTU. -Requires subclasses implement key methods like send/receive frames, enable/disable interface. -Establishes universal network interface capabilities. +- Abstract base class defining core interface properties like MAC address, speed, MTU. +- Requires subclasses implement key methods like send/receive frames, enable/disable interface. +- Establishes universal network interface capabilities. +- Malicious Network Events Monitoring: + + * Enhances network interfaces with the capability to monitor and capture Malicious Network Events (MNEs) based on predefined criteria such as specific keywords or traffic patterns. + * Integrates Number of Malicious Network Events (NMNE) detection functionalities, leveraging configurable settings like ``capture_nmne``, `nmne_capture_keywords``, and observation mechanisms such as ``NicObservation`` to classify and record network anomalies. + * Offers an additional layer of security and data analysis, crucial for identifying and mitigating malicious activities within the network infrastructure. Provides vital information for network security analysis and reinforcement learning algorithms. **WiredNetworkInterface (Connection Type Layer)** diff --git a/docs/source/simulation_components/network/nodes/firewall.rst b/docs/source/simulation_components/network/nodes/firewall.rst index 73168517..2f948081 100644 --- a/docs/source/simulation_components/network/nodes/firewall.rst +++ b/docs/source/simulation_components/network/nodes/firewall.rst @@ -229,7 +229,7 @@ To limit database server access to selected external IP addresses: position=7 ) -**Permitting DMZ Web Server Access while Blocking Specific Threats* +**Permitting DMZ Web Server Access while Blocking Specific Threats** To authorize HTTP/HTTPS access to a DMZ-hosted web server, excluding known malicious IPs: diff --git a/docs/source/simulation_components/network/nodes/network_node.rst b/docs/source/simulation_components/network/nodes/network_node.rst index eb9997ba..33bcea5b 100644 --- a/docs/source/simulation_components/network/nodes/network_node.rst +++ b/docs/source/simulation_components/network/nodes/network_node.rst @@ -27,7 +27,7 @@ in the transmission and routing of data within the simulated environment. **Key Features:** - **Frame Processing:** Central to the class is the ability to receive and process network frames, facilitating the -simulation of data flow through network devices. + simulation of data flow through network devices. - **Abstract Methods:** Includes abstract methods such as ``receive_frame``, which subclasses must implement to specify how devices handle incoming traffic. diff --git a/docs/source/simulation_components/system/data_manipulation_bot.rst b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst similarity index 63% rename from docs/source/simulation_components/system/data_manipulation_bot.rst rename to docs/source/simulation_components/system/applications/data_manipulation_bot.rst index 1fd5e5c8..9188733b 100644 --- a/docs/source/simulation_components/system/data_manipulation_bot.rst +++ b/docs/source/simulation_components/system/applications/data_manipulation_bot.rst @@ -2,20 +2,22 @@ © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +.. _DataManipulationBot: DataManipulationBot -=================== +################### -The ``DataManipulationBot`` class provides functionality to connect to a ``DatabaseService`` and execute malicious SQL statements. +The ``DataManipulationBot`` class provides functionality to connect to a :ref:`DatabaseService` and execute malicious SQL statements. Overview --------- +======== The bot is intended to simulate a malicious actor carrying out attacks like: - Dropping tables - Deleting records - Modifying data + on a database server by abusing an application's trusted database connectivity. The bot performs attacks in the following stages to simulate the real pattern of an attack: @@ -27,7 +29,7 @@ The bot performs attacks in the following stages to simulate the real pattern of Each of these stages has a random, configurable probability of succeeding (by default 10%). The bot can also be configured to repeat the attack once complete. Usage ------ +===== - Create an instance and call ``configure`` to set: - Target database server IP @@ -40,34 +42,55 @@ The bot handles connecting, executing the statement, and disconnecting. In a simulation, the bot can be controlled by using ``DataManipulationAgent`` which calls ``run`` on the bot at configured timesteps. -Example -------- +Implementation +============== + +The bot connects to a :ref:`DatabaseClient` and leverages its connectivity. The host running ``DataManipulationBot`` must also have a :ref:`DatabaseClient` installed on it. + +- Uses the Application base class for lifecycle management. +- Credentials, target IP and other options set via ``configure``. +- ``run`` handles connecting, executing statement, and disconnecting. +- SQL payload executed via ``query`` method. +- Results in malicious SQL being executed on remote database server. + + +Examples +======== + +Python +"""""" .. code-block:: python + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState + from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot + from primaite.simulator.system.applications.database_client import DatabaseClient + client_1 = Computer( hostname="client_1", ip_address="192.168.10.21", subnet_mask="255.255.255.0", - default_gateway="192.168.10.1" + default_gateway="192.168.10.1", operating_state=NodeOperatingState.ON # initialise the computer in an ON state ) network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) + client_1.software_manager.install(DatabaseClient) client_1.software_manager.install(DataManipulationBot) data_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE") data_manipulation_bot.run() -This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to drop the 'users' table. +This would connect to the database service at 192.168.1.14, authenticate, and execute the SQL statement to delete database contents. Example with ``DataManipulationAgent`` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +"""""""""""""""""""""""""""""""""""""" If not using the data manipulation bot manually, it needs to be used with a data manipulation agent. Below is an example section of configuration file for setting up a simulation with data manipulation bot and agent. .. code-block:: yaml - game_config: + game: # ... agents: - ref: data_manipulation_red_bot @@ -78,7 +101,7 @@ If not using the data manipulation bot manually, it needs to be used with a data type: UC2RedObservation options: nodes: - - node_ref: client_1 + - node_name: client_1 observations: - logon_status - operating_status @@ -95,7 +118,7 @@ If not using the data manipulation bot manually, it needs to be used with a data - type: NODE_APPLICATION_EXECUTE options: nodes: - - node_ref: client_1 + - node_name: client_1 applications: - application_ref: data_manipulation_bot max_folders_per_node: 1 @@ -127,14 +150,56 @@ If not using the data manipulation bot manually, it needs to be used with a data data_manipulation_p_of_success: 0.1 payload: "DELETE" server_ip: 192.168.1.14 + - ref: web_server_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 -Implementation --------------- +Configuration +============= -The bot extends ``DatabaseClient`` and leverages its connectivity. +.. include:: ../common/common_configuration.rst -- Uses the Application base class for lifecycle management. -- Credentials, target IP and other options set via ``configure``. -- ``run`` handles connecting, executing statement, and disconnecting. -- SQL payload executed via ``query`` method. -- Results in malicious SQL being executed on remote database server. +.. |SOFTWARE_NAME| replace:: DataManipulationBot +.. |SOFTWARE_NAME_BACKTICK| replace:: ``DataManipulationBot`` + +``server_ip`` +""""""""""""" + +IP address of the :ref:`DatabaseService` which the ``DataManipulationBot`` will try to attack. + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``server_password`` +""""""""""""""""""" + +Optional. Default value is ``None``. + +The password that the ``DataManipulationBot`` will use to access the :ref:`DatabaseService`. + +``payload`` +""""""""""" + +Optional. Default value is ``DELETE``. + +The payload that the ``DataManipulationBot`` will send to the :ref:`DatabaseService`. + +.. include:: ../common/db_payload_list.rst + +``port_scan_p_of_success`` +"""""""""""""""""""""""""" + +Optional. Default value is ``0.1``. + +The chance of the ``DataManipulationBot`` to succeed with a port scan (and therefore continue the attack). + +This must be a float value between ``0`` and ``1``. + +``data_manipulation_p_of_success`` +"""""""""""""""""""""""""""""""""" + +Optional. Default value is ``0.1``. + +The chance of the ``DataManipulationBot`` to succeed with a data manipulation attack. + +This must be a float value between ``0`` and ``1``. diff --git a/docs/source/simulation_components/system/applications/database_client.rst b/docs/source/simulation_components/system/applications/database_client.rst new file mode 100644 index 00000000..ddf6db11 --- /dev/null +++ b/docs/source/simulation_components/system/applications/database_client.rst @@ -0,0 +1,106 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _DatabaseClient: + +DatabaseClient +############## + +The ``DatabaseClient`` provides a client interface for connecting to the :ref:`DatabaseService`. + +Key features +============ + +- Connects to the :ref:`DatabaseService` via the ``SoftwareManager``. +- Handles connecting and disconnecting. +- Executes SQL queries and retrieves result sets. + +Usage +===== + +- Initialise with server IP address and optional password. +- Connect to the :ref:`DatabaseService` with ``connect``. +- Retrieve results in a dictionary. +- Disconnect when finished. + +Implementation +============== + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Connect and disconnect methods manage sessions. +- Payloads serialised as dictionaries for transmission. +- Extends base Application class. + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from ipaddress import IPv4Address + + from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.applications.database_client import DatabaseClient + + client = Computer( + hostname="client", + ip_address="192.168.10.21", + subnet_mask="255.255.255.0", + default_gateway="192.168.10.1", + operating_state=NodeOperatingState.ON # initialise the computer in an ON state + ) + + # install DatabaseClient + client.software_manager.install(DatabaseClient) + + database_client: DatabaseClient = client.software_manager.software.get("DatabaseClient") + + # Configure the DatabaseClient + database_client.configure(server_ip_address=IPv4Address("192.168.0.1")) # address of the DatabaseService + database_client.run() + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_computer + hostname: example_computer + type: computer + ... + applications: + - ref: database_client + type: DatabaseClient + options: + db_server_ip: 192.168.0.1 + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: DatabaseClient +.. |SOFTWARE_NAME_BACKTICK| replace:: ``DatabaseClient`` + + +``db_server_ip`` +"""""""""""""""" + +IP address of the :ref:`DatabaseService` that the ``DatabaseClient`` will connect to + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``server_password`` +""""""""""""""""""" + +Optional. Default value is ``None``. + +The password that the ``DatabaseClient`` will use to access the :ref:`DatabaseService`. diff --git a/docs/source/simulation_components/system/applications/dos_bot.rst b/docs/source/simulation_components/system/applications/dos_bot.rst new file mode 100644 index 00000000..6ddbac72 --- /dev/null +++ b/docs/source/simulation_components/system/applications/dos_bot.rst @@ -0,0 +1,160 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _DoSBot: + +DoSBot +###### + +The ``DoSBot`` is an implementation of a Denial of Service attack within the PrimAITE simulation. This specifically simulates a `Slow Loris attack `. + +Key features +============ + +- Connects to the :ref:`DatabaseService` via the ``SoftwareManager``. +- Makes many connections to the :ref:`DatabaseService` which ends up using up the available connections. + +Usage +===== + +- Configure with target IP address and optional password. +- use ``run`` to run the application_loop of DoSBot to begin attacks +- DoSBot runs through different actions at each timestep + +Implementation +============== + +- Leverages :ref:`DatabaseClient` to create connections with :ref`DatabaseServer`. +- Extends base Application class. + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from ipaddress import IPv4Address + + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.applications.red_applications.dos_bot import DoSBot + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Install DoSBot on computer + computer.software_manager.install(DoSBot) + dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") + + # Configure the DoSBot + dos_bot.configure( + target_ip_address=IPv4Address("192.168.0.10"), + payload="SPOOF DATA", + repeat=False, + port_scan_p_of_success=0.8, + dos_intensity=1.0, + max_sessions=1000 + ) + + # run DoSBot + dos_bot.run() + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_computer + hostname: example_computer + type: computer + ... + applications: + - ref: dos_bot + type: DoSBot + options: + target_ip_address: 192.168.0.10 + payload: SPOOF DATA + repeat: False + port_scan_p_of_success: 0.8 + dos_intensity: 1.0 + max_sessions: 1000 + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: DoSBot +.. |SOFTWARE_NAME_BACKTICK| replace:: ``DoSBot`` + +``target_ip_address`` +""""""""""""""""""""" + +IP address of the :ref:`DatabaseService` which the ``DataManipulationBot`` will try to attack. + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``target_port`` +""""""""""""""" + +Optional. Default value is ``5432``. + +Port of the target service. + +See :ref:`List of IPProtocols ` for a list of protocols. + +``payload`` +""""""""""" + +Optional. Default value is ``None``. + +The payload that the ``DoSBot`` sends as part of its attack. + +.. include:: ../common/db_payload_list.rst + +``repeat`` +"""""""""" + +Optional. Default value is ``False``. + +If ``True`` the ``DoSBot`` will maintain its attack. + +``port_scan_p_of_success`` +"""""""""""""""""""""""""" + +Optional. Default value is ``0.1``. + +The chance of the ``DoSBot`` to succeed with a port scan (and therefore continue the attack). + +This must be a float value between ``0`` and ``1``. + +``dos_intensity`` +""""""""""""""""" + +Optional. Default value is ``1.0``. + +The intensity of the Denial of Service attack. This is multiplied by the number of ``max_sessions``. + +This must be a float value between ``0`` and ``1``. + +``max_sessions`` +"""""""""""""""" + +Optional. Default value is ``1000``. + +The maximum number of sessions the ``DoSBot`` is able to make. + +This must be an integer value equal to or greater than ``0``. diff --git a/docs/source/simulation_components/system/applications/web_browser.rst b/docs/source/simulation_components/system/applications/web_browser.rst new file mode 100644 index 00000000..c46089ba --- /dev/null +++ b/docs/source/simulation_components/system/applications/web_browser.rst @@ -0,0 +1,111 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _WebBrowser: + +WebBrowser +########## + +The ``WebBrowser`` provides a client interface for connecting to the :ref:`WebServer`. + +Key features +============ + +- Connects to the :ref:`WebServer` via the ``SoftwareManager``. +- Simulates HTTP requests and HTTP packet transfer across a network +- Allows the emulation of HTTP requests between client and server: + - Automatically uses ``DNSClient`` to resolve domain names + - GET: performs an HTTP GET request from client to server +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +===== + +- Install on a Node via the ``SoftwareManager`` to start the ``WebBrowser``. +- Service runs on HTTP port 80 by default. (TODO: HTTPS) +- Execute sending an HTTP GET request with ``get_webpage`` + +Implementation +============== + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Provides easy interface for making HTTP requests between an HTTP client and server. +- Extends base Service class. + + +Examples +======== + +Python +"""""" + +The ``WebBrowser`` utilises :ref:`DNSClient` and :ref:`DNSServer` to resolve a URL. + +The :ref:`DNSClient` must be configured to use the :ref:`DNSServer`. The :ref:`DNSServer` should be configured to have the ``WebBrowser`` ``target_url`` within its ``domain_mapping``. + +.. code-block:: python + + from primaite.simulator.network.hardware.nodes.host.computer import Computer + from primaite.simulator.system.applications.web_browser import WebBrowser + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Install WebBrowser on computer + computer.software_manager.install(WebBrowser) + web_browser: WebBrowser = computer.software_manager.software.get("WebBrowser") + web_browser.run() + + # configure the WebBrowser + web_browser.target_url = "arcd.com" + + # once DNS server is configured with the correct domain mapping + # this should work + web_browser.get_webpage() + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_computer + hostname: example_computer + type: computer + ... + applications: + - ref: web_browser + type: WebBrowser + options: + target_url: http://arcd.com/ + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: WebBrowser +.. |SOFTWARE_NAME_BACKTICK| replace:: ``WebBrowser`` + +``target_url`` +"""""""""""""" + +The URL that the ``WebBrowser`` will request when ``get_webpage`` is called without parameters. + +The URL can be in any format so long as the domain is within it e.g. + +The domain ``arcd.com`` can be matched by + +- http://arcd.com/ +- http://arcd.com/users/ +- arcd.com diff --git a/docs/source/simulation_components/system/common/common_configuration.rst b/docs/source/simulation_components/system/common/common_configuration.rst new file mode 100644 index 00000000..27625407 --- /dev/null +++ b/docs/source/simulation_components/system/common/common_configuration.rst @@ -0,0 +1,18 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +``ref`` +======= + +Human readable name used as reference for the |SOFTWARE_NAME_BACKTICK|. Not used in code. + +``type`` +======== + +The type of software that should be added. To add |SOFTWARE_NAME| this must be |SOFTWARE_NAME_BACKTICK|. + +``options`` +=========== + +The configuration options are the attributes that fall under the options for an application. diff --git a/docs/source/simulation_components/system/common/db_payload_list.rst b/docs/source/simulation_components/system/common/db_payload_list.rst new file mode 100644 index 00000000..f51227c6 --- /dev/null +++ b/docs/source/simulation_components/system/common/db_payload_list.rst @@ -0,0 +1,11 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _Database Payload List: + +Available Database Payloads: + +- ``SELECT`` +- ``INSERT`` +- ``DELETE`` diff --git a/docs/source/simulation_components/system/database_client_server.rst b/docs/source/simulation_components/system/database_client_server.rst deleted file mode 100644 index 07912f3e..00000000 --- a/docs/source/simulation_components/system/database_client_server.rst +++ /dev/null @@ -1,71 +0,0 @@ -.. only:: comment - - © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - - -Database Client Server -====================== - -Database Service ----------------- - -The ``DatabaseService`` provides a SQL database server simulation by extending the base Service class. - -Key capabilities -^^^^^^^^^^^^^^^^ - -- Creates a database file in the ``Node`` 's ``FileSystem`` upon creation. -- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs. -- Authenticates connections using a configurable password. -- Simulates ``SELECT``, ``DELETE`` and ``INSERT`` SQL queries. -- Returns query results and status codes back to clients. -- Leverages the Service base class for install/uninstall, status tracking, etc. - -Usage -^^^^^ -- Install on a Node via the ``SoftwareManager`` to start the database service. -- Clients connect, execute queries, and disconnect. -- Service runs on TCP port 5432 by default. - -Implementation -^^^^^^^^^^^^^^ - -- Creates the database file within the node's file system. -- Manages client connections in a dictionary by session ID. -- Processes SQL queries. -- Returns results and status codes in a standard dictionary format. -- Extends Service class for integration with ``SoftwareManager``. - -Database Client ---------------- - -The DatabaseClient provides a client interface for connecting to the ``DatabaseService``. - -Key features -^^^^^^^^^^^^ - -- Connects to the ``DatabaseService`` via the ``SoftwareManager``. -- Handles connecting and disconnecting. -- Executes SQL queries and retrieves result sets. - -Usage -^^^^^ - -- Initialise with server IP address and optional password. -- Connect to the ``DatabaseService`` with ``connect``. -- Retrieve results in a dictionary. -- Disconnect when finished. - -To create database backups: - -- Configure the backup server on the ``DatabaseService`` by providing the Backup server ``IPv4Address`` with ``configure_backup`` -- Create a backup using ``backup_database``. This fails if the backup server is not configured. -- Restore a backup using ``restore_backup``. By default, this uses the database created via ``backup_database``. - -Implementation -^^^^^^^^^^^^^^ - -- Leverages ``SoftwareManager`` for sending payloads over the network. -- Connect and disconnect methods manage sessions. -- Payloads serialised as dictionaries for transmission. -- Extends base Application class. diff --git a/docs/source/simulation_components/system/dns_client_server.rst b/docs/source/simulation_components/system/dns_client_server.rst deleted file mode 100644 index f57f903b..00000000 --- a/docs/source/simulation_components/system/dns_client_server.rst +++ /dev/null @@ -1,56 +0,0 @@ -.. only:: comment - - © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - -DNS Client Server -================= - -DNS Server ----------- -Also known as a DNS Resolver, the ``DNSServer`` provides a DNS Server simulation by extending the base Service class. - -Key capabilities -^^^^^^^^^^^^^^^^ - -- Simulates DNS requests and DNSPacket transfer across a network -- Registers domain names and the IP Address linked to the domain name -- Returns the IP address for a given domain name within a DNS Packet that a DNS Client can read -- Leverages the Service base class for install/uninstall, status tracking, etc. - -Usage -^^^^^ -- Install on a Node via the ``SoftwareManager`` to start the database service. -- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) - -Implementation -^^^^^^^^^^^^^^ - -- DNS request and responses use a ``DNSPacket`` object -- Extends Service class for integration with ``SoftwareManager``. - -DNS Client ----------- - -The DNSClient provides a client interface for connecting to the ``DNSServer``. - -Key features -^^^^^^^^^^^^ - -- Connects to the ``DNSServer`` via the ``SoftwareManager``. -- Executes DNS lookup requests and keeps a cache of known domain name IP addresses. -- Handles connection to DNSServer and querying for domain name IP addresses. - -Usage -^^^^^ - -- Install on a Node via the ``SoftwareManager`` to start the database service. -- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) -- Execute domain name checks with ``check_domain_exists``. -- ``DNSClient`` will automatically add the IP Address of the domain into its cache - -Implementation -^^^^^^^^^^^^^^ - -- Leverages ``SoftwareManager`` for sending payloads over the network. -- Provides easy interface for Nodes to find IP addresses via domain names. -- Extends base Service class. diff --git a/docs/source/simulation_components/system/ftp_client_server.rst b/docs/source/simulation_components/system/ftp_client_server.rst deleted file mode 100644 index a544b4c8..00000000 --- a/docs/source/simulation_components/system/ftp_client_server.rst +++ /dev/null @@ -1,135 +0,0 @@ -.. only:: comment - - © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - -FTP Client Server -================= - -FTP Server ----------- -Provides a FTP Client-Server simulation by extending the base Service class. - -Key capabilities -^^^^^^^^^^^^^^^^ - -- Simulates FTP requests and FTPPacket transfer across a network -- Allows the emulation of FTP commands between an FTP client and server: - - STOR: stores a file from client to server - - RETR: retrieves a file from the FTP server -- Leverages the Service base class for install/uninstall, status tracking, etc. - -Usage -^^^^^ -- Install on a Node via the ``SoftwareManager`` to start the FTP server service. -- Service runs on FTP (command) port 21 by default. (TODO: look at in depth implementation of FTP PORT command) - -Implementation -^^^^^^^^^^^^^^ - -- FTP request and responses use a ``FTPPacket`` object -- Extends Service class for integration with ``SoftwareManager``. - -FTP Client ----------- - -The ``FTPClient`` provides a client interface for connecting to the ``FTPServer``. - -Key features -^^^^^^^^^^^^ - -- Connects to the ``FTPServer`` via the ``SoftwareManager``. -- Simulates FTP requests and FTPPacket transfer across a network -- Allows the emulation of FTP commands between an FTP client and server: - - PORT: specifies the port that server should connect to on the client (currently only uses ``Port.FTP``) - - STOR: stores a file from client to server - - RETR: retrieves a file from the FTP server - - QUIT: disconnect from server -- Leverages the Service base class for install/uninstall, status tracking, etc. - -Usage -^^^^^ - -- Install on a Node via the ``SoftwareManager`` to start the FTP client service. -- Service runs on FTP (command) port 21 by default. (TODO: look at in depth implementation of FTP PORT command) -- Execute sending a file to the FTP server with ``send_file`` -- Execute retrieving a file from the FTP server with ``request_file`` - -Implementation -^^^^^^^^^^^^^^ - -- Leverages ``SoftwareManager`` for sending payloads over the network. -- Provides easy interface for Nodes to transfer files between each other. -- Extends base Service class. - - -Example Usage -------------- - -Dependencies -^^^^^^^^^^^^ - -.. code-block:: python - - from ipaddress import IPv4Address - - from primaite.simulator.network.container import Network - from primaite.simulator.network.hardware.nodes.computer import Computer - from primaite.simulator.network.hardware.nodes.server import Server - from primaite.simulator.system.services.ftp.ftp_server import FTPServer - from primaite.simulator.system.services.ftp.ftp_client import FTPClient - from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState - -Example peer to peer network -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: python - - net = Network() - - pc1 = Computer( - hostname="pc1", - ip_address="120.10.10.10", - subnet_mask="255.255.255.0", - operating_state=NodeOperatingState.ON # initialise the computer in an ON state - ) - srv = Server( - hostname="srv", - ip_address="120.10.10.20", - subnet_mask="255.255.255.0", - operating_state=NodeOperatingState.ON # initialise the server in an ON state - ) - net.connect(pc1.network_interface[1], srv.network_interface[1]) - -Install the FTP Server -^^^^^^^^^^^^^^^^^^^^^^ - -FTP Client should be pre installed on nodes - -.. code-block:: python - - srv.software_manager.install(FTPServer) - ftpserv: FTPServer = srv.software_manager.software['FTPServer'] - -Setting up the FTP Server -^^^^^^^^^^^^^^^^^^^^^^^^^ - -Set up the FTP Server with a file that the client will need to retrieve - -.. code-block:: python - - srv.file_system.create_file('my_file.png') - -Check that file was retrieved -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: python - - client.request_file( - src_folder_name='root', - src_file_name='my_file.png', - dest_folder_name='root', - dest_file_name='test.png', - dest_ip_address=IPv4Address("120.10.10.20") - ) - - print(client.get_file(folder_name="root", file_name="test.png")) diff --git a/docs/source/simulation_components/system/list_of_applications.rst b/docs/source/simulation_components/system/list_of_applications.rst new file mode 100644 index 00000000..8f792e4c --- /dev/null +++ b/docs/source/simulation_components/system/list_of_applications.rst @@ -0,0 +1,15 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. toctree:: + :maxdepth: 1 + :glob: + + applications/* + +More info :py:mod:`primaite.game.game.APPLICATION_TYPES_MAPPING` + +.. include:: list_of_system_applications.rst + +.. |SOFTWARE_TYPE| replace:: application diff --git a/docs/source/simulation_components/system/list_of_services.rst b/docs/source/simulation_components/system/list_of_services.rst new file mode 100644 index 00000000..9f1c9fe2 --- /dev/null +++ b/docs/source/simulation_components/system/list_of_services.rst @@ -0,0 +1,15 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. toctree:: + :maxdepth: 1 + :glob: + + services/* + +More info :py:mod:`primaite.game.game.SERVICE_TYPES_MAPPING` + +.. include:: list_of_system_services.rst + +.. |SOFTWARE_TYPE| replace:: service diff --git a/docs/source/simulation_components/system/list_of_system_applications.rst b/docs/source/simulation_components/system/list_of_system_applications.rst new file mode 100644 index 00000000..193b3dc6 --- /dev/null +++ b/docs/source/simulation_components/system/list_of_system_applications.rst @@ -0,0 +1,16 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +``system applications`` +""""""""""""""""""""""" + +Some applications are pre installed on nodes - this is similar to how some applications are included with the Operating System. + +The application may not be configured as needed, in which case, see the relevant application page. + +The list of applications that are considered system software are: + +- ``WebBrowser`` + +More info :py:mod:`primaite.simulator.network.hardware.nodes.host.host_node.HostNode.SYSTEM_SOFTWARE` diff --git a/docs/source/simulation_components/system/list_of_system_services.rst b/docs/source/simulation_components/system/list_of_system_services.rst new file mode 100644 index 00000000..5acfc12e --- /dev/null +++ b/docs/source/simulation_components/system/list_of_system_services.rst @@ -0,0 +1,18 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +``system services`` +""""""""""""""""""" + +Some services are pre installed on nodes - this is similar to how some services are included with the Operating System. + +The service may not be configured as needed, in which case, see the relevant service page. + +The list of services that are considered system software are: + +- ``DNSClient`` +- ``FTPClient`` +- ``NTPClient`` + +More info :py:mod:`primaite.simulator.network.hardware.nodes.host.host_node.HostNode.SYSTEM_SOFTWARE` diff --git a/docs/source/simulation_components/system/ntp_client_server.rst b/docs/source/simulation_components/system/ntp_client_server.rst deleted file mode 100644 index b6d57c13..00000000 --- a/docs/source/simulation_components/system/ntp_client_server.rst +++ /dev/null @@ -1,54 +0,0 @@ -.. only:: comment - - © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - -NTP Client Server -================= - -NTP Server ----------- -The ``NTPServer`` provides a NTP Server simulation by extending the base Service class. - -NTP Client ----------- -The ``NTPClient`` provides a NTP Client simulation by extending the base Service class. - -Key capabilities -^^^^^^^^^^^^^^^^ - -- Simulates NTP requests and NTPPacket transfer across a network -- Leverages the Service base class for install/uninstall, status tracking, etc. - -Usage -^^^^^ -- Install on a Node via the ``SoftwareManager`` to start the database service. -- Service runs on UDP port 123 by default. - -Implementation -^^^^^^^^^^^^^^ - -- NTP request and responses use a ``NTPPacket`` object -- Extends Service class for integration with ``SoftwareManager``. - -NTP Client ----------- - -The NTPClient provides a client interface for connecting to the ``NTPServer``. - -Key features -^^^^^^^^^^^^ - -- Connects to the ``NTPServer`` via the ``SoftwareManager``. - -Usage -^^^^^ - -- Install on a Node via the ``SoftwareManager`` to start the database service. -- Service runs on UDP port 123 by default. - -Implementation -^^^^^^^^^^^^^^ - -- Leverages ``SoftwareManager`` for sending payloads over the network. -- Provides easy interface for Nodes to find IP addresses via domain names. -- Extends base Service class. diff --git a/docs/source/simulation_components/system/services/database_service.rst b/docs/source/simulation_components/system/services/database_service.rst new file mode 100644 index 00000000..dd6dec41 --- /dev/null +++ b/docs/source/simulation_components/system/services/database_service.rst @@ -0,0 +1,116 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _DatabaseService: + +DatabaseService +############### + +The ``DatabaseService`` provides a SQL database server simulation by extending the base Service class. + +Key capabilities +================ + +- Creates a database file in the ``FileSystem`` of the ``Node`` (which the ``DatabaseService`` is installed on) upon creation. +- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs. +- Authenticates connections using a configurable password. +- Simulates ``SELECT``, ``DELETE`` and ``INSERT`` SQL queries. +- Returns query results and status codes back to clients. +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +===== +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Clients connect, execute queries, and disconnect. +- Service runs on TCP port 5432 by default. + +**Supported queries:** + +* ``SELECT``: As long as the database file is in a ``GOOD`` health state, the db service will respond with a 200 status code. +* ``DELETE``: This query represents an attack, it will cause the database file to enter a ``COMPROMISED`` state, and return a status code 200. +* ``INSERT``: If the database service is in a healthy state, this will return a 200 status, if it's not in a healthy state it will return 404. +* ``SELECT * FROM pg_stat_activity``: This query represents something an admin would send to check the status of the database. If the database service is in a healthy state, it returns a 200 status code, otherwise a 401 status code. + +Implementation +============== + +- Creates the database file within the node's file system. +- Manages client connections in a dictionary by session ID. +- Processes SQL queries. +- Returns results and status codes in a standard dictionary format. +- Extends Service class for integration with ``SoftwareManager``. + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from ipaddress import IPv4Address + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.database.database_service import DatabaseService + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Install DatabaseService on server + server.software_manager.install(DatabaseService) + db_service: DatabaseService = server.software_manager.software.get("DatabaseService") + db_service.start() + + # configure DatabaseService + db_service.configure_backup(IPv4Address("192.168.0.10")) + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: database_service + type: DatabaseService + options: + backup_server_ip: 192.168.0.10 + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: DatabaseService +.. |SOFTWARE_NAME_BACKTICK| replace:: ``DatabaseService`` + +``backup_server_ip`` +"""""""""""""""""""" + +Optional. Default value is ``None``. + +The IP Address of the backup server that the ``DatabaseService`` will use to create backups of the database. + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. + +``password`` +"""""""""""" + +Optional. Default value is ``None``. + +The password that needs to be provided by connecting clients in order to create a successful connection. diff --git a/docs/source/simulation_components/system/services/dns_client.rst b/docs/source/simulation_components/system/services/dns_client.rst new file mode 100644 index 00000000..91461590 --- /dev/null +++ b/docs/source/simulation_components/system/services/dns_client.rst @@ -0,0 +1,99 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _DNSClient: + +DNSClient +######### + +The DNSClient provides a client interface for connecting to the :ref:`DNSServer`. + +Key features +============ + +- Connects to the :ref:`DNSServer` via the ``SoftwareManager``. +- Executes DNS lookup requests and keeps a cache of known domain name IP addresses. +- Handles connection to DNSServer and querying for domain name IP addresses. + +Usage +===== + +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) +- Execute domain name checks with ``check_domain_exists``. +- ``DNSClient`` will automatically add the IP Address of the domain into its cache + +Implementation +============== + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Provides easy interface for Nodes to find IP addresses via domain names. +- Extends base Service class. + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from ipaddress import IPv4Address + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.dns.dns_client import DNSClient + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Install DNSClient on server + server.software_manager.install(DNSClient) + dns_client: DNSClient = server.software_manager.software.get("DNSClient") + dns_client.start() + + # configure DatabaseService + dns_client.dns_server = IPv4Address("192.168.0.10") + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: dns_client + type: DNSClient + options: + dns_server: 192.168.0.10 + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: DNSClient +.. |SOFTWARE_NAME_BACKTICK| replace:: ``DNSClient`` + +``dns_server`` +"""""""""""""" + +Optional. Default value is ``None``. + +The IP Address of the :ref:`DNSServer`. + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. diff --git a/docs/source/simulation_components/system/services/dns_server.rst b/docs/source/simulation_components/system/services/dns_server.rst new file mode 100644 index 00000000..89ce7fc1 --- /dev/null +++ b/docs/source/simulation_components/system/services/dns_server.rst @@ -0,0 +1,98 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _DNSServer: + +DNSServer +######### + +Also known as a DNS Resolver, the ``DNSServer`` provides a DNS Server simulation by extending the base Service class. + +Key capabilities +================ + +- Simulates DNS requests and DNSPacket transfer across a network +- Registers domain names and the IP Address linked to the domain name +- Returns the IP address for a given domain name within a DNS Packet that a DNS Client can read +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +===== +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on TCP port 53 by default. (TODO: TCP for now, should be UDP in future) + +Implementation +============== + +- DNS request and responses use a ``DNSPacket`` object +- Extends Service class for integration with ``SoftwareManager``. + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from ipaddress import IPv4Address + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.dns.dns_server import DNSServer + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Install DNSServer on server + server.software_manager.install(DNSServer) + dns_server: DNSServer = server.software_manager.software.get("DNSServer") + dns_server.start() + + # configure DatabaseService + dns_server.dns_register("arcd.com", IPv4Address("192.168.10.10")) + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: dns_server + type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.0.10 + another-example.com: 192.168.10.10 + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: DNSServer +.. |SOFTWARE_NAME_BACKTICK| replace:: ``DNSServer`` + +domain_mapping +"""""""""""""" + +Domain mapping takes the domain and IP Addresses as a key-value pairs i.e. + +If the domain is "arcd.com" and the IP Address attributed to the domain is 192.168.0.10, then the value should be ``arcd.com: 192.168.0.10`` + +The key must be a string and the IP Address must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. diff --git a/docs/source/simulation_components/system/services/ftp_client.rst b/docs/source/simulation_components/system/services/ftp_client.rst new file mode 100644 index 00000000..259a626d --- /dev/null +++ b/docs/source/simulation_components/system/services/ftp_client.rst @@ -0,0 +1,91 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _FTPClient: + +FTPClient +######### + +The ``FTPClient`` provides a client interface for connecting to the :ref:`FTPServer`. + +Key features +============ + +- Connects to the :ref:`FTPServer` via the ``SoftwareManager``. +- Simulates FTP requests and FTPPacket transfer across a network +- Allows the emulation of FTP commands between an FTP client and server: + - PORT: specifies the port that server should connect to on the client (currently only uses ``Port.FTP``) + - STOR: stores a file from client to server + - RETR: retrieves a file from the FTP server + - QUIT: disconnect from server +- Leverages the Service base class for install/uninstall, status tracking, etc. +- :ref:`FTPClient` and ``FTPServer`` utilise port 21 (FTP) throughout all file transfer / request + +Usage +===== + +- Install on a Node via the ``SoftwareManager`` to start the FTP client service. +- Service runs on FTP (command) port 21 by default +- Execute sending a file to the FTP server with ``send_file`` +- Execute retrieving a file from the FTP server with ``request_file`` + +Implementation +============== + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Provides easy interface for Nodes to transfer files between each other. +- Extends base Service class. + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.ftp.ftp_client import FTPClient + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.10", + start_up_duration=0, + ) + server.power_on() + + # Install FTPClient on server + server.software_manager.install(FTPClient) + ftp_client: FTPClient = server.software_manager.software.get("FTPClient") + ftp_client.start() + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: ftp_client + type: FTPClient + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: FTPClient +.. |SOFTWARE_NAME_BACKTICK| replace:: ``FTPClient`` + +**FTPClient has no configuration options** diff --git a/docs/source/simulation_components/system/services/ftp_server.rst b/docs/source/simulation_components/system/services/ftp_server.rst new file mode 100644 index 00000000..fb57a762 --- /dev/null +++ b/docs/source/simulation_components/system/services/ftp_server.rst @@ -0,0 +1,94 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _FTPServer: + +FTPServer +######### + +Provides a FTP Client-Server simulation by extending the base Service class. + +Key capabilities +================ + +- Simulates FTP requests and FTPPacket transfer across a network +- Allows the emulation of FTP commands between an FTP client and server: + - STOR: stores a file from client to server + - RETR: retrieves a file from the FTP server +- Leverages the Service base class for install/uninstall, status tracking, etc. +- :ref:`FTPClient` and ``FTPServer`` utilise port 21 (FTP) throughout all file transfer / request + +Usage +===== + +- Install on a Node via the ``SoftwareManager`` to start the FTP server service. +- Service runs on FTP (command) port 21 by default + +Implementation +============== + +- FTP request and responses use a ``FTPPacket`` object +- Extends Service class for integration with ``SoftwareManager``. + + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.ftp.ftp_server import FTPServer + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Install FTPServer on server + server.software_manager.install(FTPServer) + ftp_server: FTPServer = server.software_manager.software.get("FTPServer") + ftp_server.start() + + ftp_server.server_password = "test" + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: ftp_server + type: FTPServer + options: + server_password: test + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: FTPServer +.. |SOFTWARE_NAME_BACKTICK| replace:: ``FTPServer`` + +``server_password`` +""""""""""""""""""" + +Optional. Default value is ``None``. + +The password that needs to be provided by a connecting :ref:`FTPClient` in order to create a successful connection. diff --git a/docs/source/simulation_components/system/services/ntp_client.rst b/docs/source/simulation_components/system/services/ntp_client.rst new file mode 100644 index 00000000..aaba3261 --- /dev/null +++ b/docs/source/simulation_components/system/services/ntp_client.rst @@ -0,0 +1,95 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _NTPClient: + +NTPClient +######### + +The NTPClient provides a client interface for connecting to the ``NTPServer``. + +Key features +============ + +- Connects to the ``NTPServer`` via the ``SoftwareManager``. + +Usage +===== + +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on UDP port 123 by default. + +Implementation +============== + +- Leverages ``SoftwareManager`` for sending payloads over the network. +- Provides easy interface for Nodes to find IP addresses via domain names. +- Extends base Service class. + + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from ipaddress import IPv4Address + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.ntp.ntp_client import NTPClient + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Install NTPClient on server + server.software_manager.install(NTPClient) + ntp_client: NTPClient = server.software_manager.software.get("NTPClient") + ntp_client.start() + + ntp_client.configure(ntp_server_ip_address=IPv4Address("192.168.0.10")) + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: ntp_client + type: NTPClient + options: + ntp_server_ip: 192.168.0.10 + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: NTPClient +.. |SOFTWARE_NAME_BACKTICK| replace:: ``NTPClient`` + +``ntp_server_ip`` +""""""""""""""""" + +Optional. Default value is ``None``. + +The IP address of an NTP Server which provides a time that the ``NTPClient`` can synchronise to. + +This must be a valid octet i.e. in the range of ``0.0.0.0`` and ``255.255.255.255``. diff --git a/docs/source/simulation_components/system/services/ntp_server.rst b/docs/source/simulation_components/system/services/ntp_server.rst new file mode 100644 index 00000000..0025b428 --- /dev/null +++ b/docs/source/simulation_components/system/services/ntp_server.rst @@ -0,0 +1,86 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _NTPServer: + +NTPServer +######### + +The ``NTPServer`` provides a NTP Server simulation by extending the base Service class. + +NTP Client +========== + +The ``NTPClient`` provides a NTP Client simulation by extending the base Service class. + +Key capabilities +================ + +- Simulates NTP requests and NTPPacket transfer across a network +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +===== +- Install on a Node via the ``SoftwareManager`` to start the database service. +- Service runs on UDP port 123 by default. + +Implementation +============== + +- NTP request and responses use a ``NTPPacket`` object +- Extends Service class for integration with ``SoftwareManager``. + + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Install NTPServer on server + server.software_manager.install(NTPServer) + ntp_server: NTPServer = server.software_manager.software.get("NTPServer") + ntp_server.start() + + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: ntp_server + type: NTPServer + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: NTPServer +.. |SOFTWARE_NAME_BACKTICK| replace:: ``NTPServer`` + +**NTPServer has no configuration options** diff --git a/docs/source/simulation_components/system/services/web_server.rst b/docs/source/simulation_components/system/services/web_server.rst new file mode 100644 index 00000000..62b1d090 --- /dev/null +++ b/docs/source/simulation_components/system/services/web_server.rst @@ -0,0 +1,86 @@ +.. only:: comment + + © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK + +.. _WebServer: + +WebServer +######### + +Provides a Web Server simulation by extending the base Service class. + +Key capabilities +================ + +- Simulates a web server with the capability to also request data from a database +- Allows the emulation of HTTP requests between client (e.g. a web browser) and server + - GET request sends a get all users request to the database server and returns an HTTP 200 status if the database is responsive +- Leverages the Service base class for install/uninstall, status tracking, etc. + +Usage +===== + +- Install on a Node via the ``SoftwareManager`` to start the `WebServer`. +- Service runs on HTTP port 80 by default. (TODO: HTTPS) +- A :ref:`DatabaseClient` must be installed and configured on the same node as the ``WebServer`` if it is intended to send a users request i.e. + in the case that the :ref:`WebBrowser` sends a request with users in its request path, the ``WebServer`` will utilise the ``DatabaseClient`` to send a request to the ``DatabaseService`` + +Implementation +============== + +- HTTP request uses a ``HttpRequestPacket`` object +- HTTP response uses a ``HttpResponsePacket`` object +- Extends Service class for integration with ``SoftwareManager``. + + +Examples +======== + +Python +"""""" + +.. code-block:: python + + from primaite.simulator.network.hardware.nodes.host.server import Server + from primaite.simulator.system.services.web_server.web_server import WebServer + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.2.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Install WebServer on server + server.software_manager.install(WebServer) + web_server: WebServer = server.software_manager.software.get("WebServer") + web_server.start() + +Via Configuration +""""""""""""""""" + +.. code-block:: yaml + + simulation: + network: + nodes: + - ref: example_server + hostname: example_server + type: server + ... + services: + - ref: web_server + type: WebServer + +Configuration +============= + +.. include:: ../common/common_configuration.rst + +.. |SOFTWARE_NAME| replace:: WebServer +.. |SOFTWARE_NAME_BACKTICK| replace:: ``WebServer`` + +**WebServer has no configuration options** diff --git a/docs/source/simulation_components/system/session_and_software_manager.rst b/docs/source/simulation_components/system/session_and_software_manager.rst index a550faf1..8af96e87 100644 --- a/docs/source/simulation_components/system/session_and_software_manager.rst +++ b/docs/source/simulation_components/system/session_and_software_manager.rst @@ -16,6 +16,8 @@ ARP, ICMP, or the Web Client. This pathway exemplifies the structured processing each frame reaches its intended target within the simulated environment. .. image:: node_session_software_model_example.png + :width: 500 + :align: center Session Manager --------------- diff --git a/docs/source/simulation_components/system/software.rst b/docs/source/simulation_components/system/software.rst index cd6b0aa3..2ba8e841 100644 --- a/docs/source/simulation_components/system/software.rst +++ b/docs/source/simulation_components/system/software.rst @@ -10,7 +10,7 @@ Software Base Software ------------- -All software which inherits ``IOSoftware`` installed on a node will not work unless the node has been turned on. +Software which inherits ``IOSoftware`` installed on a node will not work unless the node has been turned on. See :ref:`Node Start up and Shut down` @@ -39,15 +39,27 @@ See :ref:`Node Start up and Shut down` assert node.operating_state is NodeOperatingState.ON assert web_server.operating_state is ServiceOperatingState.RUNNING # service turned back on when node is powered on +.. _List of Applications: -Services, Processes and Applications: -##################################### +Applications +############ -.. toctree:: - :maxdepth: 2 +These are a list of applications that are currently available in PrimAITE: - database_client_server - data_manipulation_bot - dns_client_server - ftp_client_server - web_browser_and_web_server_service +.. include:: list_of_applications.rst + +.. _List of Services: + +Services +######## + +These are a list of services that are currently available in PrimAITE: + +.. include:: list_of_services.rst + +.. _List of Processes: + +Processes +######### + +`To be implemented` diff --git a/docs/source/simulation_components/system/web_browser_and_web_server_service.rst b/docs/source/simulation_components/system/web_browser_and_web_server_service.rst deleted file mode 100644 index 538baa58..00000000 --- a/docs/source/simulation_components/system/web_browser_and_web_server_service.rst +++ /dev/null @@ -1,110 +0,0 @@ -.. only:: comment - - © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK - -Web Browser and Web Server Service -================================== - -Web Server Service ------------------- -Provides a Web Server simulation by extending the base Service class. - -Key capabilities -^^^^^^^^^^^^^^^^ - -- Simulates a web server with the capability to also request data from a database -- Allows the emulation of HTTP requests between client (e.g. a web browser) and server - - GET request sends a get all users request to the database server and returns an HTTP 200 status if the database is responsive -- Leverages the Service base class for install/uninstall, status tracking, etc. - -Usage -^^^^^ -- Install on a Node via the ``SoftwareManager`` to start the `WebServer`. -- Service runs on HTTP port 80 by default. (TODO: HTTPS) - -Implementation -^^^^^^^^^^^^^^ - -- HTTP request uses a ``HttpRequestPacket`` object -- HTTP response uses a ``HttpResponsePacket`` object -- Extends Service class for integration with ``SoftwareManager``. - -Web Browser (Web Client) ------------------------- - -The ``WebBrowser`` provides a client interface for connecting to the ``WebServer``. - -Key features -^^^^^^^^^^^^ - -- Connects to the ``WebServer`` via the ``SoftwareManager``. -- Simulates HTTP requests and HTTP packet transfer across a network -- Allows the emulation of HTTP requests between client and server: - - Automatically uses ``DNSClient`` to resolve domain names - - GET: performs an HTTP GET request from client to server -- Leverages the Service base class for install/uninstall, status tracking, etc. - -Usage -^^^^^ - -- Install on a Node via the ``SoftwareManager`` to start the ``WebBrowser``. -- Service runs on HTTP port 80 by default. (TODO: HTTPS) -- Execute sending an HTTP GET request with ``get_webpage`` - -Implementation -^^^^^^^^^^^^^^ - -- Leverages ``SoftwareManager`` for sending payloads over the network. -- Provides easy interface for making HTTP requests between an HTTP client and server. -- Extends base Service class. - - -Example Usage -------------- - -Dependencies -^^^^^^^^^^^^ - -.. code-block:: python - - from primaite.simulator.network.container import Network - from primaite.simulator.network.hardware.nodes.computer import Computer - from primaite.simulator.network.hardware.nodes.server import Server - from primaite.simulator.system.applications.web_browser import WebBrowser - from primaite.simulator.system.services.web_server.web_server_service import WebServer - -Example peer to peer network -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: python - - net = Network() - - pc1 = Computer(hostname="pc1", ip_address="192.168.1.50", subnet_mask="255.255.255.0") - srv = Server(hostname="srv", ip_address="192.168.1.10", subnet_mask="255.255.255.0") - pc1.power_on() - srv.power_on() - net.connect(pc1.network_interface[1], srv.network_interface[1]) - -Install the Web Server -^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: python - - # web browser is automatically installed in computer nodes - # IRL this is usually included with an OS - client: WebBrowser = pc1.software_manager.software['WebBrowser'] - - # install web server - srv.software_manager.install(WebServer) - webserv: WebServer = srv.software_manager.software['WebServer'] - -Open the web page -^^^^^^^^^^^^^^^^^ - -Using a domain name to connect to a website requires setting up DNS Servers. For this example, it is possible to use the IP address directly - -.. code-block:: python - - # check that the get request succeeded - print(client.get_webpage("http://192.168.1.10")) # should be True diff --git a/docs/source/simulation_structure.rst b/docs/source/simulation_structure.rst index 6e0ab5ce..f9a69b26 100644 --- a/docs/source/simulation_structure.rst +++ b/docs/source/simulation_structure.rst @@ -12,14 +12,15 @@ and a domain controller for managing software and users. Each node of the simulation 'tree' has responsibility for creating, deleting, and updating its direct descendants. Also, when a component's ``describe_state()`` method is called, it will include the state of its descendants. The -``apply_request()`` method can be used to act on a component or one of its descendatnts. The diagram below shows the +``apply_request()`` method can be used to act on a component or one of its descendants. The diagram below shows the relationship between components. -.. image:: _static/component_relationship.png +.. image:: ../../_static/component_relationship.png :width: 500 - :alt: The top level simulation object owns a NetworkContainer and a DomainController. The DomainController has a - list of accounts. The network container has links and nodes. Nodes can own switchports, NICs, FileSystem, - Application, Service, and Process. + :align: center + :alt: :: The top level simulation object owns a NetworkContainer and a DomainController. The DomainController has a + list of accounts. The network container has links and nodes. Nodes can own switchports, NICs, FileSystem, + Application, Service, and Process. Actions diff --git a/docs/source/state_system.rst b/docs/source/state_system.rst index 860c9827..5fc12c23 100644 --- a/docs/source/state_system.rst +++ b/docs/source/state_system.rst @@ -3,11 +3,11 @@ © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK Simulation State -============== +================ -``SimComponent`` objects in the simulation have a method called ``describe_state`` which return a dictionary of the state of the component. This is used to report pertinent data that could impact an agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` object reports not only its own attributes in the state but also those of its child components. I.e. a computer node will report the state of its ``FileSystem`` and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling the childrens' own ``describe_state`` methods. +``SimComponent`` objects in the simulation have a method called ``describe_state`` which return a dictionary of the state of the component. This is used to report pertinent data that could impact an agent's actions or rewards. For instance, the name and health status of a node is reported, which can be used by a reward function to punish corrupted or compromised nodes and reward healthy nodes. Each ``SimComponent`` object reports not only its own attributes in the state but also those of its child components. I.e. a computer node will report the state of its ``FileSystem`` and the ``FileSystem`` will report the state of its files and folders. This happens by recursively calling the children's own ``describe_state`` methods. -The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then passes the state to the agents once per simulation step. For this reason, all ``SimComponent`` objetcs must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``. +The game layer calls ``describe_state`` on the trunk ``SimComponent`` (the top-level parent) and then passes the state to the agents once per simulation step. For this reason, all ``SimComponent`` objects must have a ``describe_state`` method, and they must all be linked to the trunk ``SimComponent``. This code snippet demonstrates how the state information is defined within the ``SimComponent`` class: diff --git a/pyproject.toml b/pyproject.toml index 3e5b959a..19b5b7fa 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ dev = [ "build==0.10.0", "flake8==6.0.0", "flake8-annotations", - "furo==2023.3.27", + "furo==2024.01.29", "gputil==1.4.0", "pip-licenses==4.3.0", "pre-commit==2.20.0", @@ -67,7 +67,7 @@ dev = [ "pytest-cov==4.0.0", "pytest-flake8==1.1.1", "setuptools==66", - "Sphinx==6.1.3", + "Sphinx==7.1.2", "sphinx-copybutton==0.5.2", "wheel==0.38.4" ] diff --git a/src/primaite/cli.py b/src/primaite/cli.py index 81ab2792..18d21f7b 100644 --- a/src/primaite/cli.py +++ b/src/primaite/cli.py @@ -127,10 +127,10 @@ def session( :param config: The path to the config file. Optional, if None, the example config will be used. :type config: Optional[str] """ - from primaite.config.load import example_config_path + from primaite.config.load import data_manipulation_config_path from primaite.main import run if not config: - config = example_config_path() + config = data_manipulation_config_path() print(config) run(config_path=config, agent_load_path=agent_load_file) diff --git a/src/primaite/config/_package_data/data_manipulation.yaml b/src/primaite/config/_package_data/data_manipulation.yaml new file mode 100644 index 00000000..c561030a --- /dev/null +++ b/src/primaite/config/_package_data/data_manipulation.yaml @@ -0,0 +1,960 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: + type: UC2RedObservation + options: + nodes: {} + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: # options specific to this particular agent type, basically args of __init__(self) + start_settings: + start_step: 25 + frequency: 20 + variance: 5 + + - ref: defender + team: BLUE + type: ProxyAgent + + observation_space: + type: UC2BlueObservation + options: + num_services_per_node: 1 + num_folders_per_node: 1 + num_files_per_folder: 1 + num_nics_per_node: 2 + nodes: + - node_hostname: domain_controller + services: + - service_name: DNSServer + - node_hostname: web_server + services: + - service_name: WebServer + - node_hostname: database_server + folders: + - folder_name: database + files: + - file_name: database.db + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 + links: + - link_ref: router_1___switch_1 + - link_ref: router_1___switch_2 + - link_ref: switch_1___domain_controller + - link_ref: switch_1___web_server + - link_ref: switch_1___database_server + - link_ref: switch_1___backup_server + - link_ref: switch_1___security_suite + - link_ref: switch_2___client_1 + - link_ref: switch_2___client_2 + - link_ref: switch_2___security_suite + acl: + options: + max_acl_rules: 10 + router_hostname: router_1 + ip_address_order: + - node_hostname: domain_controller + nic_num: 1 + - node_hostname: web_server + nic_num: 1 + - node_hostname: database_server + nic_num: 1 + - node_hostname: backup_server + nic_num: 1 + - node_hostname: security_suite + nic_num: 1 + - node_hostname: client_1 + nic_num: 1 + - node_hostname: client_2 + nic_num: 1 + - node_hostname: security_suite + nic_num: 2 + ics: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SERVICE_SCAN + - type: NODE_SERVICE_STOP + - type: NODE_SERVICE_START + - type: NODE_SERVICE_PAUSE + - type: NODE_SERVICE_RESUME + - type: NODE_SERVICE_RESTART + - type: NODE_SERVICE_DISABLE + - type: NODE_SERVICE_ENABLE + - type: NODE_SERVICE_PATCH + - type: NODE_FILE_SCAN + - type: NODE_FILE_CHECKHASH + - type: NODE_FILE_DELETE + - type: NODE_FILE_REPAIR + - type: NODE_FILE_RESTORE + - type: NODE_FOLDER_SCAN + - type: NODE_FOLDER_CHECKHASH + - type: NODE_FOLDER_REPAIR + - type: NODE_FOLDER_RESTORE + - type: NODE_OS_SCAN + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: NODE_RESET + - type: NETWORK_ACL_ADDRULE + options: + target_router_hostname: router_1 + - type: NETWORK_ACL_REMOVERULE + options: + target_router_hostname: router_1 + - type: NETWORK_NIC_ENABLE + - type: NETWORK_NIC_DISABLE + + action_map: + 0: + action: DONOTHING + options: {} + # scan webapp service + 1: + action: NODE_SERVICE_SCAN + options: + node_id: 1 + service_id: 0 + # stop webapp service + 2: + action: NODE_SERVICE_STOP + options: + node_id: 1 + service_id: 0 + # start webapp service + 3: + action: "NODE_SERVICE_START" + options: + node_id: 1 + service_id: 0 + 4: + action: "NODE_SERVICE_PAUSE" + options: + node_id: 1 + service_id: 0 + 5: + action: "NODE_SERVICE_RESUME" + options: + node_id: 1 + service_id: 0 + 6: + action: "NODE_SERVICE_RESTART" + options: + node_id: 1 + service_id: 0 + 7: + action: "NODE_SERVICE_DISABLE" + options: + node_id: 1 + service_id: 0 + 8: + action: "NODE_SERVICE_ENABLE" + options: + node_id: 1 + service_id: 0 + 9: # check database.db file + action: "NODE_FILE_SCAN" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 10: + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 11: + action: "NODE_FILE_DELETE" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 12: + action: "NODE_FILE_REPAIR" + options: + node_id: 2 + folder_id: 0 + file_id: 0 + 13: + action: "NODE_SERVICE_PATCH" + options: + node_id: 2 + service_id: 0 + 14: + action: "NODE_FOLDER_SCAN" + options: + node_id: 2 + folder_id: 0 + 15: + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. + options: + node_id: 2 + folder_id: 0 + 16: + action: "NODE_FOLDER_REPAIR" + options: + node_id: 2 + folder_id: 0 + 17: + action: "NODE_FOLDER_RESTORE" + options: + node_id: 2 + folder_id: 0 + 18: + action: "NODE_OS_SCAN" + options: + node_id: 0 + 19: + action: "NODE_SHUTDOWN" + options: + node_id: 0 + 20: + action: NODE_STARTUP + options: + node_id: 0 + 21: + action: NODE_RESET + options: + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" + action: "NETWORK_ACL_ADDRULE" + options: + position: 1 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" + action: "NETWORK_ACL_ADDRULE" + options: + position: 2 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "NETWORK_ACL_ADDRULE" + options: + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 49: # old action num: 25 # block tcp traffic from client 2 to web app + action: "NETWORK_ACL_ADDRULE" + options: + position: 4 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 50: # old action num: 26 + action: "NETWORK_ACL_ADDRULE" + options: + position: 5 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 51: # old action num: 27 + action: "NETWORK_ACL_ADDRULE" + options: + position: 6 + permission: 2 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database + source_port_id: 1 + dest_port_id: 1 + protocol_id: 3 + 52: # old action num: 28 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 0 + 53: # old action num: 29 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 1 + 54: # old action num: 30 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 2 + 55: # old action num: 31 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 3 + 56: # old action num: 32 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 4 + 57: # old action num: 33 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 5 + 58: # old action num: 34 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 6 + 59: # old action num: 35 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 7 + 60: # old action num: 36 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 8 + 61: # old action num: 37 + action: "NETWORK_ACL_REMOVERULE" + options: + position: 9 + 62: # old action num: 38 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 0 + nic_id: 0 + 63: # old action num: 39 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 0 + nic_id: 0 + 64: # old action num: 40 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 1 + nic_id: 0 + 65: # old action num: 41 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 1 + nic_id: 0 + 66: # old action num: 42 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 2 + nic_id: 0 + 67: # old action num: 43 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 2 + nic_id: 0 + 68: # old action num: 44 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 3 + nic_id: 0 + 69: # old action num: 45 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 3 + nic_id: 0 + 70: # old action num: 46 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 1 + 73: # old action num: 49 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 1 + 74: # old action num: 50 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 5 + nic_id: 0 + 75: # old action num: 51 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 5 + nic_id: 0 + 76: # old action num: 52 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 6 + nic_id: 0 + 77: # old action num: 53 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 6 + nic_id: 0 + + + + options: + nodes: + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient + services: + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db + services: + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + + max_folders_per_node: 2 + max_files_per_folder: 2 + max_services_per_node: 2 + max_nics_per_node: 8 + max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user + + + + agent_settings: + flatten_obs: true + + + + + +simulation: + network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE + nodes: + + - ref: router_1 + hostname: router_1 + type: router + num_ports: 5 + ports: + 1: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + 18: + action: PERMIT + src_port: POSTGRES_SERVER + dst_port: POSTGRES_SERVER + 19: + action: PERMIT + src_port: DNS + dst_port: DNS + 20: + action: PERMIT + src_port: FTP + dst_port: FTP + 21: + action: PERMIT + src_port: HTTP + dst_port: HTTP + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + + - ref: switch_1 + hostname: switch_1 + type: switch + num_ports: 8 + + - ref: switch_2 + hostname: switch_2 + type: switch + num_ports: 8 + + - ref: domain_controller + hostname: domain_controller + type: server + ip_address: 192.168.1.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - ref: domain_controller_dns_server + type: DNSServer + options: + domain_mapping: + arcd.com: 192.168.1.12 # web server + + - ref: web_server + hostname: web_server + type: server + ip_address: 192.168.1.12 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: web_server_web_service + type: WebServer + applications: + - ref: web_server_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + + + - ref: database_server + hostname: database_server + type: server + ip_address: 192.168.1.14 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: database_service + type: DatabaseService + options: + backup_server_ip: 192.168.1.16 + - ref: database_ftp_client + type: FTPClient + + - ref: backup_server + hostname: backup_server + type: server + ip_address: 192.168.1.16 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + services: + - ref: backup_service + type: FTPServer + + - ref: security_suite + hostname: security_suite + type: server + ip_address: 192.168.1.110 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + dns_server: 192.168.1.10 + network_interfaces: + 2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot + ip_address: 192.168.10.110 + subnet_mask: 255.255.255.0 + + - ref: client_1 + hostname: client_1 + type: computer + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - ref: client_1_dns_client + type: DNSClient + + - ref: client_2 + hostname: client_2 + type: computer + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + applications: + - ref: client_2_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 + services: + - ref: client_2_dns_client + type: DNSClient + + + + links: + - ref: router_1___switch_1 + endpoint_a_ref: router_1 + endpoint_a_port: 1 + endpoint_b_ref: switch_1 + endpoint_b_port: 8 + - ref: router_1___switch_2 + endpoint_a_ref: router_1 + endpoint_a_port: 2 + endpoint_b_ref: switch_2 + endpoint_b_port: 8 + - ref: switch_1___domain_controller + endpoint_a_ref: switch_1 + endpoint_a_port: 1 + endpoint_b_ref: domain_controller + endpoint_b_port: 1 + - ref: switch_1___web_server + endpoint_a_ref: switch_1 + endpoint_a_port: 2 + endpoint_b_ref: web_server + endpoint_b_port: 1 + - ref: switch_1___database_server + endpoint_a_ref: switch_1 + endpoint_a_port: 3 + endpoint_b_ref: database_server + endpoint_b_port: 1 + - ref: switch_1___backup_server + endpoint_a_ref: switch_1 + endpoint_a_port: 4 + endpoint_b_ref: backup_server + endpoint_b_port: 1 + - ref: switch_1___security_suite + endpoint_a_ref: switch_1 + endpoint_a_port: 7 + endpoint_b_ref: security_suite + endpoint_b_port: 1 + - ref: switch_2___client_1 + endpoint_a_ref: switch_2 + endpoint_a_port: 1 + endpoint_b_ref: client_1 + endpoint_b_port: 1 + - ref: switch_2___client_2 + endpoint_a_ref: switch_2 + endpoint_a_port: 2 + endpoint_b_ref: client_2 + endpoint_b_port: 1 + - ref: switch_2___security_suite + endpoint_a_ref: switch_2 + endpoint_a_port: 7 + endpoint_b_ref: security_suite + endpoint_b_port: 2 diff --git a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml b/src/primaite/config/_package_data/data_manipulation_marl.yaml similarity index 55% rename from src/primaite/config/_package_data/example_config_2_rl_agents.yaml rename to src/primaite/config/_package_data/data_manipulation_marl.yaml index 3a6feb68..85d282ba 100644 --- a/src/primaite/config/_package_data/example_config_2_rl_agents.yaml +++ b/src/primaite/config/_package_data/data_manipulation_marl.yaml @@ -1,19 +1,26 @@ training_config: rl_framework: RLLIB_multi_agent - # rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false n_agents: 2 agent_references: - defender_1 - defender_2 + io_settings: - save_checkpoints: true - checkpoint_interval: 5 + save_agent_actions: true save_step_metadata: false + save_pcap_logs: false + save_sys_logs: true game: - max_episode_length: 256 + max_episode_length: 128 ports: - ARP - DNS @@ -27,7 +34,12 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -36,25 +48,95 @@ agents: - type: NODE_APPLICATION_EXECUTE options: nodes: - - node_ref: client_2 + - node_name: client_2 applications: - - application_ref: client_2_web_browser + - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + - application_name: DatabaseClient + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 - - ref: client_1_data_manipulation_red_bot + reward_function: + reward_components: + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 + + + + + + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -72,9 +154,12 @@ agents: - type: NODE_OS_SCAN options: nodes: - - node_ref: client_1 + - node_name: client_1 applications: - - application_ref: data_manipulation_bot + - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -101,25 +186,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -134,23 +215,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -181,10 +262,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -208,319 +289,441 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_PATCH" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 + node_id: 0 22: - action: "NETWORK_ACL_ADDRULE" + action: "NODE_OS_SCAN" options: - position: 1 - permission: 2 - source_ip_id: 7 - dest_ip_id: 1 - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 + node_id: 1 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "NETWORK_ACL_ADDRULE" + options: + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 49: # old action num: 25 # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 26: + 50: # old action num: 26 action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 27: + 51: # old action num: 27 action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 28: + 52: # old action num: 28 action: "NETWORK_ACL_REMOVERULE" options: position: 0 - 29: + 53: # old action num: 29 action: "NETWORK_ACL_REMOVERULE" options: position: 1 - 30: + 54: # old action num: 30 action: "NETWORK_ACL_REMOVERULE" options: position: 2 - 31: + 55: # old action num: 31 action: "NETWORK_ACL_REMOVERULE" options: position: 3 - 32: + 56: # old action num: 32 action: "NETWORK_ACL_REMOVERULE" options: position: 4 - 33: + 57: # old action num: 33 action: "NETWORK_ACL_REMOVERULE" options: position: 5 - 34: + 58: # old action num: 34 action: "NETWORK_ACL_REMOVERULE" options: position: 6 - 35: + 59: # old action num: 35 action: "NETWORK_ACL_REMOVERULE" options: position: 7 - 36: + 60: # old action num: 36 action: "NETWORK_ACL_REMOVERULE" options: position: 8 - 37: + 61: # old action num: 37 action: "NETWORK_ACL_REMOVERULE" options: position: 9 - 38: + 62: # old action num: 38 action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 - 39: + nic_id: 0 + 63: # old action num: 39 action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 - 40: + nic_id: 0 + 64: # old action num: 40 action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 - 41: + nic_id: 0 + 65: # old action num: 41 action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 - 42: + nic_id: 0 + 66: # old action num: 42 action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 - 43: + nic_id: 0 + 67: # old action num: 43 action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 - 44: + nic_id: 0 + 68: # old action num: 44 action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 - 45: + nic_id: 0 + 69: # old action num: 45 action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 - 46: + nic_id: 0 + 70: # old action num: 46 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 action: "NETWORK_NIC_DISABLE" options: node_id: 4 nic_id: 1 - 47: + 73: # old action num: 49 action: "NETWORK_NIC_ENABLE" options: node_id: 4 nic_id: 1 - 48: - action: "NETWORK_NIC_DISABLE" - options: - node_id: 4 - nic_id: 2 - 49: - action: "NETWORK_NIC_ENABLE" - options: - node_id: 4 - nic_id: 2 - 50: + 74: # old action num: 50 action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 - 51: + nic_id: 0 + 75: # old action num: 51 action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 - 52: + nic_id: 0 + 76: # old action num: 52 action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 - 53: + nic_id: 0 + 77: # old action num: 53 action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 + reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_ref: web_server - service_ref: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... - + flatten_obs: true - ref: defender_2 team: BLUE @@ -534,25 +737,21 @@ agents: num_files_per_folder: 1 num_nics_per_node: 2 nodes: - - node_ref: domain_controller + - node_hostname: domain_controller services: - - service_ref: domain_controller_dns_server - - node_ref: web_server + - service_name: DNSServer + - node_hostname: web_server services: - - service_ref: web_server_database_client - - node_ref: database_server - services: - - service_ref: database_service + - service_name: WebServer + - node_hostname: database_server folders: - folder_name: database files: - file_name: database.db - - node_ref: backup_server - # services: - # - service_ref: backup_service - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - node_hostname: backup_server + - node_hostname: security_suite + - node_hostname: client_1 + - node_hostname: client_2 links: - link_ref: router_1___switch_1 - link_ref: router_1___switch_2 @@ -567,23 +766,23 @@ agents: acl: options: max_acl_rules: 10 - router_node_ref: router_1 + router_hostname: router_1 ip_address_order: - - node_ref: domain_controller + - node_hostname: domain_controller nic_num: 1 - - node_ref: web_server + - node_hostname: web_server nic_num: 1 - - node_ref: database_server + - node_hostname: database_server nic_num: 1 - - node_ref: backup_server + - node_hostname: backup_server nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 1 - - node_ref: client_1 + - node_hostname: client_1 nic_num: 1 - - node_ref: client_2 + - node_hostname: client_2 nic_num: 1 - - node_ref: security_suite + - node_hostname: security_suite nic_num: 2 ics: null @@ -614,10 +813,10 @@ agents: - type: NODE_RESET - type: NETWORK_ACL_ADDRULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_ACL_REMOVERULE options: - target_router_ref: router_1 + target_router_hostname: router_1 - type: NETWORK_NIC_ENABLE - type: NETWORK_NIC_DISABLE @@ -641,330 +840,456 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: - action: "NODE_FILE_CHECKHASH" + action: "NODE_FILE_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 1 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_PATCH" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 15: - action: "NODE_FOLDER_CHECKHASH" + action: "NODE_FOLDER_SCAN" # CHECKHASH replaced by SCAN - but the behaviour is the same in this context. options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 1 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 + node_id: 0 22: - action: "NETWORK_ACL_ADDRULE" + action: "NODE_OS_SCAN" options: - position: 1 - permission: 2 - source_ip_id: 7 - dest_ip_id: 1 - source_port_id: 1 - dest_port_id: 1 - protocol_id: 1 + node_id: 1 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 permission: 2 - source_ip_id: 8 - dest_ip_id: 1 + source_ip_id: 7 # client 1 + dest_ip_id: 1 # ALL source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 2 permission: 2 - source_ip_id: 7 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 1 # ALL + source_port_id: 1 + dest_port_id: 1 + protocol_id: 1 + 48: # old action num: 24 # block tcp traffic from client 1 to web app + action: "NETWORK_ACL_ADDRULE" + options: + position: 3 + permission: 2 + source_ip_id: 7 # client 1 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: + 49: # old action num: 25 # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 4 permission: 2 - source_ip_id: 8 - dest_ip_id: 3 + source_ip_id: 8 # client 2 + dest_ip_id: 3 # web server source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 26: + 50: # old action num: 26 action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 5 permission: 2 - source_ip_id: 7 - dest_ip_id: 4 + source_ip_id: 7 # client 1 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 27: + 51: # old action num: 27 action: "NETWORK_ACL_ADDRULE" options: - position: 1 + position: 6 permission: 2 - source_ip_id: 8 - dest_ip_id: 4 + source_ip_id: 8 # client 2 + dest_ip_id: 4 # database source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 28: + 52: # old action num: 28 action: "NETWORK_ACL_REMOVERULE" options: position: 0 - 29: + 53: # old action num: 29 action: "NETWORK_ACL_REMOVERULE" options: position: 1 - 30: + 54: # old action num: 30 action: "NETWORK_ACL_REMOVERULE" options: position: 2 - 31: + 55: # old action num: 31 action: "NETWORK_ACL_REMOVERULE" options: position: 3 - 32: + 56: # old action num: 32 action: "NETWORK_ACL_REMOVERULE" options: position: 4 - 33: + 57: # old action num: 33 action: "NETWORK_ACL_REMOVERULE" options: position: 5 - 34: + 58: # old action num: 34 action: "NETWORK_ACL_REMOVERULE" options: position: 6 - 35: + 59: # old action num: 35 action: "NETWORK_ACL_REMOVERULE" options: position: 7 - 36: + 60: # old action num: 36 action: "NETWORK_ACL_REMOVERULE" options: position: 8 - 37: + 61: # old action num: 37 action: "NETWORK_ACL_REMOVERULE" options: position: 9 - 38: + 62: # old action num: 38 action: "NETWORK_NIC_DISABLE" options: node_id: 0 - nic_id: 1 - 39: + nic_id: 0 + 63: # old action num: 39 action: "NETWORK_NIC_ENABLE" options: node_id: 0 - nic_id: 1 - 40: + nic_id: 0 + 64: # old action num: 40 action: "NETWORK_NIC_DISABLE" options: node_id: 1 - nic_id: 1 - 41: + nic_id: 0 + 65: # old action num: 41 action: "NETWORK_NIC_ENABLE" options: node_id: 1 - nic_id: 1 - 42: + nic_id: 0 + 66: # old action num: 42 action: "NETWORK_NIC_DISABLE" options: node_id: 2 - nic_id: 1 - 43: + nic_id: 0 + 67: # old action num: 43 action: "NETWORK_NIC_ENABLE" options: node_id: 2 - nic_id: 1 - 44: + nic_id: 0 + 68: # old action num: 44 action: "NETWORK_NIC_DISABLE" options: node_id: 3 - nic_id: 1 - 45: + nic_id: 0 + 69: # old action num: 45 action: "NETWORK_NIC_ENABLE" options: node_id: 3 - nic_id: 1 - 46: + nic_id: 0 + 70: # old action num: 46 + action: "NETWORK_NIC_DISABLE" + options: + node_id: 4 + nic_id: 0 + 71: # old action num: 47 + action: "NETWORK_NIC_ENABLE" + options: + node_id: 4 + nic_id: 0 + 72: # old action num: 48 action: "NETWORK_NIC_DISABLE" options: node_id: 4 nic_id: 1 - 47: + 73: # old action num: 49 action: "NETWORK_NIC_ENABLE" options: node_id: 4 nic_id: 1 - 48: - action: "NETWORK_NIC_DISABLE" - options: - node_id: 4 - nic_id: 2 - 49: - action: "NETWORK_NIC_ENABLE" - options: - node_id: 4 - nic_id: 2 - 50: + 74: # old action num: 50 action: "NETWORK_NIC_DISABLE" options: node_id: 5 - nic_id: 1 - 51: + nic_id: 0 + 75: # old action num: 51 action: "NETWORK_NIC_ENABLE" options: node_id: 5 - nic_id: 1 - 52: + nic_id: 0 + 76: # old action num: 52 action: "NETWORK_NIC_DISABLE" options: node_id: 6 - nic_id: 1 - 53: + nic_id: 0 + 77: # old action num: 53 action: "NETWORK_NIC_ENABLE" options: node_id: 6 - nic_id: 1 + nic_id: 0 + options: nodes: - - node_ref: domain_controller - - node_ref: web_server + - node_name: domain_controller + - node_name: web_server + applications: + - application_name: DatabaseClient services: - - service_ref: web_server_web_service - - node_ref: database_server + - service_name: WebServer + - node_name: database_server + folders: + - folder_name: database + files: + - file_name: database.db services: - - service_ref: database_service - - node_ref: backup_server - - node_ref: security_suite - - node_ref: client_1 - - node_ref: client_2 + - service_name: DatabaseService + - node_name: backup_server + - node_name: security_suite + - node_name: client_1 + - node_name: client_2 + max_folders_per_node: 2 max_files_per_folder: 2 max_services_per_node: 2 max_nics_per_node: 8 max_acl_rules: 10 + ip_address_order: + - node_name: domain_controller + nic_num: 1 + - node_name: web_server + nic_num: 1 + - node_name: database_server + nic_num: 1 + - node_name: backup_server + nic_num: 1 + - node_name: security_suite + nic_num: 1 + - node_name: client_1 + nic_num: 1 + - node_name: client_2 + nic_num: 1 + - node_name: security_suite + nic_num: 2 reward_function: reward_components: - type: DATABASE_FILE_INTEGRITY - weight: 0.5 + weight: 0.40 options: - node_ref: database_server + node_hostname: database_server folder_name: database file_name: database.db - - - - type: WEB_SERVER_404_PENALTY - weight: 0.5 + - type: SHARED_REWARD + weight: 1.0 options: - node_ref: web_server - service_ref: web_server_web_service + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 + options: + agent_name: client_2_green_user agent_settings: - # ... - + flatten_obs: true simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - ref: router_1 - type: router hostname: router_1 + type: router num_ports: 5 ports: 1: @@ -999,18 +1324,18 @@ simulation: protocol: ICMP - ref: switch_1 - type: switch hostname: switch_1 + type: switch num_ports: 8 - ref: switch_2 - type: switch hostname: switch_2 + type: switch num_ports: 8 - ref: domain_controller - type: server hostname: domain_controller + type: server ip_address: 192.168.1.10 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1022,24 +1347,25 @@ simulation: arcd.com: 192.168.1.12 # web server - ref: web_server - type: server hostname: web_server + type: server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server - type: server hostname: database_server + type: server ip_address: 192.168.1.14 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1053,8 +1379,8 @@ simulation: type: FTPClient - ref: backup_server - type: server hostname: backup_server + type: server ip_address: 192.168.1.16 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1064,8 +1390,8 @@ simulation: type: FTPServer - ref: security_suite - type: server hostname: security_suite + type: server ip_address: 192.168.1.110 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -1076,8 +1402,8 @@ simulation: subnet_mask: 255.255.255.0 - ref: client_1 - type: computer hostname: client_1 + type: computer ip_address: 192.168.10.21 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -1086,17 +1412,25 @@ simulation: - ref: data_manipulation_bot type: DataManipulationBot options: - port_scan_p_of_success: 0.1 - data_manipulation_p_of_success: 0.1 + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 payload: "DELETE" server_ip: 192.168.1.14 + - ref: client_1_web_browser + type: WebBrowser + options: + target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_1_dns_client type: DNSClient - ref: client_2 - type: computer hostname: client_2 + type: computer ip_address: 192.168.10.22 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -1106,6 +1440,17 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/src/primaite/config/load.py b/src/primaite/config/load.py index b01eb129..d5acd690 100644 --- a/src/primaite/config/load.py +++ b/src/primaite/config/load.py @@ -30,14 +30,14 @@ def load(file_path: Union[str, Path]) -> Dict: return config -def example_config_path() -> Path: +def data_manipulation_config_path() -> Path: """ Get the path to the example config. :return: Path to the example config. :rtype: Path """ - path = _EXAMPLE_CFG / "example_config.yaml" + path = _EXAMPLE_CFG / "data_manipulation.yaml" if not path.exists(): msg = f"Example config does not exist: {path}. Have you run `primaite setup`?" _LOGGER.error(msg) diff --git a/src/primaite/game/agent/actions.py b/src/primaite/game/agent/actions.py index 1793d420..4d28328e 100644 --- a/src/primaite/game/agent/actions.py +++ b/src/primaite/game/agent/actions.py @@ -492,9 +492,9 @@ class NetworkACLAddRuleAction(AbstractAction): "add_rule", permission_str, protocol, - src_ip, + str(src_ip), src_port, - dst_ip, + str(dst_ip), dst_port, position, ] @@ -572,7 +572,7 @@ class NetworkNICDisableAction(NetworkNICAbstractAction): class ActionManager: """Class which manages the action space for an agent.""" - _act_class_identifiers: Dict[str, type] = { + act_class_identifiers: Dict[str, type] = { "DONOTHING": DoNothingAction, "NODE_SERVICE_SCAN": NodeServiceScanAction, "NODE_SERVICE_STOP": NodeServiceStopAction, @@ -607,7 +607,6 @@ class ActionManager: def __init__( self, - game: "PrimaiteGame", # reference to game for information lookup actions: List[Dict], # stores list of actions available to agent nodes: List[Dict], # extra configuration for each node max_folders_per_node: int = 2, # allows calculating shape @@ -618,7 +617,7 @@ class ActionManager: max_acl_rules: int = 10, # allows calculating shape protocols: List[str] = ["TCP", "UDP", "ICMP"], # allow mapping index to protocol ports: List[str] = ["HTTP", "DNS", "ARP", "FTP", "NTP"], # allow mapping index to port - ip_address_list: Optional[List[str]] = None, # to allow us to map an index to an ip address. + ip_address_list: List[str] = [], # to allow us to map an index to an ip address. act_map: Optional[Dict[int, Dict]] = None, # allows restricting set of possible actions ) -> None: """Init method for ActionManager. @@ -649,7 +648,6 @@ class ActionManager: :param act_map: Action map which maps integers to actions. Used for restricting the set of possible actions. :type act_map: Optional[Dict[int, Dict]] """ - self.game: "PrimaiteGame" = game self.node_names: List[str] = [n["node_name"] for n in nodes] """List of node names in this action space. The list order is the mapping between node index and node name.""" self.application_names: List[List[str]] = [] @@ -707,25 +705,7 @@ class ActionManager: self.protocols: List[str] = protocols self.ports: List[str] = ports - self.ip_address_list: List[str] - - # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from - # the nodes in the simulation. - # TODO: refactor. Options: - # 1: This should be pulled out into it's own function for clarity - # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to - # go through the nodes here. - if ip_address_list is not None: - self.ip_address_list = ip_address_list - else: - self.ip_address_list = [] - for node_name in self.node_names: - node_obj = self.game.simulation.network.get_node_by_hostname(node_name) - if node_obj is None: - continue - network_interfaces = node_obj.network_interfaces - for nic_uuid, nic_obj in network_interfaces.items(): - self.ip_address_list.append(nic_obj.ip_address) + self.ip_address_list: List[str] = ip_address_list # action_args are settings which are applied to the action space as a whole. global_action_args = { @@ -753,7 +733,7 @@ class ActionManager: # and `options` is an optional dict of options to pass to the init method of the action class act_type = act_spec.get("type") act_options = act_spec.get("options", {}) - self.actions[act_type] = self._act_class_identifiers[act_type](self, **global_action_args, **act_options) + self.actions[act_type] = self.act_class_identifiers[act_type](self, **global_action_args, **act_options) self.action_map: Dict[int, Tuple[str, Dict]] = {} """ @@ -832,6 +812,13 @@ class ActionManager: :return: The node hostname. :rtype: str """ + if not node_idx < len(self.node_names): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx}, but its action space only" + f"has {len(self.node_names)} nodes." + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.node_names[node_idx] def get_folder_name_by_idx(self, node_idx: int, folder_idx: int) -> Optional[str]: @@ -845,6 +832,13 @@ class ActionManager: :return: The name of the folder. Or None if the node has fewer folders than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.folder_names) or folder_idx >= len(self.folder_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and folder {folder_idx}, but this" + f" is out of range for its action space. Folder on each node: {self.folder_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.folder_names[node_idx][folder_idx] def get_file_name_by_idx(self, node_idx: int, folder_idx: int, file_idx: int) -> Optional[str]: @@ -860,6 +854,17 @@ class ActionManager: fewer files than the given index. :rtype: Optional[str] """ + if ( + node_idx >= len(self.file_names) + or folder_idx >= len(self.file_names[node_idx]) + or file_idx >= len(self.file_names[node_idx][folder_idx]) + ): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} folder {folder_idx} file {file_idx}" + f" but this is out of range for its action space. Files on each node: {self.file_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.file_names[node_idx][folder_idx][file_idx] def get_service_name_by_idx(self, node_idx: int, service_idx: int) -> Optional[str]: @@ -872,6 +877,13 @@ class ActionManager: :return: The name of the service. Or None if the node has fewer services than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.service_names) or service_idx >= len(self.service_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and service {service_idx}, but this" + f" is out of range for its action space. Services on each node: {self.service_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.service_names[node_idx][service_idx] def get_application_name_by_idx(self, node_idx: int, application_idx: int) -> Optional[str]: @@ -884,6 +896,13 @@ class ActionManager: :return: The name of the service. Or None if the node has fewer services than the given index. :rtype: Optional[str] """ + if node_idx >= len(self.application_names) or application_idx >= len(self.application_names[node_idx]): + msg = ( + f"Error: agent attempted to perform an action on node {node_idx} and app {application_idx}, but " + f"this is out of range for its action space. Applications on each node: {self.application_names}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.application_names[node_idx][application_idx] def get_internet_protocol_by_idx(self, protocol_idx: int) -> str: @@ -894,6 +913,13 @@ class ActionManager: :return: The protocol. :rtype: str """ + if protocol_idx >= len(self.protocols): + msg = ( + f"Error: agent attempted to perform an action on protocol {protocol_idx} but this" + f" is out of range for its action space. Protocols: {self.protocols}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.protocols[protocol_idx] def get_ip_address_by_idx(self, ip_idx: int) -> str: @@ -905,6 +931,13 @@ class ActionManager: :return: The IP address. :rtype: str """ + if ip_idx >= len(self.ip_address_list): + msg = ( + f"Error: agent attempted to perform an action on ip address {ip_idx} but this" + f" is out of range for its action space. IP address list: {self.ip_address_list}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.ip_address_list[ip_idx] def get_port_by_idx(self, port_idx: int) -> str: @@ -916,6 +949,13 @@ class ActionManager: :return: The port. :rtype: str """ + if port_idx >= len(self.ports): + msg = ( + f"Error: agent attempted to perform an action on port {port_idx} but this" + f" is out of range for its action space. Port list: {self.ip_address_list}" + ) + _LOGGER.error(msg) + raise RuntimeError(msg) return self.ports[port_idx] def get_nic_num_by_idx(self, node_idx: int, nic_idx: int) -> int: @@ -958,6 +998,12 @@ class ActionManager: :return: The constructed ActionManager. :rtype: ActionManager """ + # If the user has provided a list of IP addresses, use that. Otherwise, generate a list of IP addresses from + # the nodes in the simulation. + # TODO: refactor. Options: + # 1: This should be pulled out into it's own function for clarity + # 2: The simulation itself should be able to provide a list of IP addresses with its API, rather than having to + # go through the nodes here. ip_address_order = cfg["options"].pop("ip_address_order", {}) ip_address_list = [] for entry in ip_address_order: @@ -967,13 +1013,22 @@ class ActionManager: ip_address = node_obj.network_interface[nic_num].ip_address ip_address_list.append(ip_address) + if not ip_address_list: + node_names = [n["node_name"] for n in cfg.get("nodes", {})] + for node_name in node_names: + node_obj = game.simulation.network.get_node_by_hostname(node_name) + if node_obj is None: + continue + network_interfaces = node_obj.network_interfaces + for nic_uuid, nic_obj in network_interfaces.items(): + ip_address_list.append(nic_obj.ip_address) + obj = cls( - game=game, actions=cfg["action_list"], **cfg["options"], protocols=game.options.protocols, ports=game.options.ports, - ip_address_list=ip_address_list or None, + ip_address_list=ip_address_list, act_map=cfg.get("action_map"), ) diff --git a/src/primaite/game/agent/data_manipulation_bot.py b/src/primaite/game/agent/data_manipulation_bot.py deleted file mode 100644 index 58b790ec..00000000 --- a/src/primaite/game/agent/data_manipulation_bot.py +++ /dev/null @@ -1,52 +0,0 @@ -import random -from typing import Dict, List, Tuple - -from gymnasium.core import ObsType - -from primaite.game.agent.interface import AbstractScriptedAgent -from primaite.simulator.system.applications.red_applications.data_manipulation_bot import DataManipulationBot - - -class DataManipulationAgent(AbstractScriptedAgent): - """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" - - data_manipulation_bots: List["DataManipulationBot"] = [] - next_execution_timestep: int = 0 - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) - - def _set_next_execution_timestep(self, timestep: int) -> None: - """Set the next execution timestep with a configured random variance. - - :param timestep: The timestep to add variance to. - """ - random_timestep_increment = random.randint( - -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance - ) - self.next_execution_timestep = timestep + random_timestep_increment - - def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. - - :param obs: _description_ - :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ - :rtype: Tuple[str, Dict] - """ - current_timestep = self.action_manager.game.step_counter - - if current_timestep < self.next_execution_timestep: - return "DONOTHING", {"dummy": 0} - - self._set_next_execution_timestep(current_timestep + self.agent_settings.start_settings.frequency) - - return "NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0} - - def reset_agent_for_episode(self) -> None: - """Set the next execution timestep when the episode resets.""" - super().reset_agent_for_episode() - self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) diff --git a/src/primaite/game/agent/interface.py b/src/primaite/game/agent/interface.py index 276715f7..cd4a1c29 100644 --- a/src/primaite/game/agent/interface.py +++ b/src/primaite/game/agent/interface.py @@ -1,18 +1,38 @@ """Interface for agents.""" from abc import ABC, abstractmethod -from typing import Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from gymnasium.core import ActType, ObsType from pydantic import BaseModel, model_validator from primaite.game.agent.actions import ActionManager -from primaite.game.agent.observations import ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager from primaite.game.agent.rewards import RewardFunction +from primaite.interface.request import RequestFormat, RequestResponse if TYPE_CHECKING: pass +class AgentActionHistoryItem(BaseModel): + """One entry of an agent's action log - what the agent did and how the simulator responded in 1 step.""" + + timestep: int + """Timestep of this action.""" + + action: str + """CAOS Action name.""" + + parameters: Dict[str, Any] + """CAOS parameters for the given action.""" + + request: RequestFormat + """The request that was sent to the simulation based on the CAOS action chosen.""" + + response: RequestResponse + """The response sent back by the simulator for this action.""" + + class AgentStartSettings(BaseModel): """Configuration values for when an agent starts performing actions.""" @@ -90,6 +110,7 @@ class AbstractAgent(ABC): self.observation_manager: Optional[ObservationManager] = observation_space self.reward_function: Optional[RewardFunction] = reward_function self.agent_settings = agent_settings or AgentSettings() + self.action_history: List[AgentActionHistoryItem] = [] def update_observation(self, state: Dict) -> ObsType: """ @@ -109,10 +130,10 @@ class AbstractAgent(ABC): :return: Reward from the state. :rtype: float """ - return self.reward_function.update(state) + return self.reward_function.update(state=state, last_action_response=self.action_history[-1]) @abstractmethod - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return an action to be taken in the environment. @@ -120,8 +141,8 @@ class AbstractAgent(ABC): :param obs: Observation of the environment. :type obs: ObsType - :param reward: Reward from the previous action, defaults to None TODO: should this parameter even be accepted? - :type reward: float, optional + :param timestep: The current timestep in the simulation, used for non-RL agents. Optional + :type timestep: int :return: Action to be taken in the environment. :rtype: Tuple[str, Dict] """ @@ -136,31 +157,24 @@ class AbstractAgent(ABC): request = self.action_manager.form_request(action_identifier=action, action_options=options) return request - def reset_agent_for_episode(self) -> None: - """Agent reset logic should go here.""" - pass + def process_action_response( + self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse + ) -> None: + """Process the response from the most recent action.""" + self.action_history.append( + AgentActionHistoryItem( + timestep=timestep, action=action, parameters=parameters, request=request, response=response + ) + ) class AbstractScriptedAgent(AbstractAgent): """Base class for actors which generate their own behaviour.""" - ... - - -class RandomAgent(AbstractScriptedAgent): - """Agent that ignores its observation and acts completely at random.""" - - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: - """Randomly sample an action from the action space. - - :param obs: _description_ - :type obs: ObsType - :param reward: _description_, defaults to None - :type reward: float, optional - :return: _description_ - :rtype: Tuple[str, Dict] - """ - return self.action_manager.get_action(self.action_manager.space.sample()) + @abstractmethod + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """Return an action to be taken in the environment.""" + return super().get_action(obs=obs, timestep=timestep) class ProxyAgent(AbstractAgent): @@ -183,14 +197,14 @@ class ProxyAgent(AbstractAgent): self.most_recent_action: ActType self.flatten_obs: bool = agent_settings.flatten_obs if agent_settings else False - def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: """ Return the agent's most recent action, formatted in CAOS format. :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface. :type obs: ObsType - :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None. - :type reward: float, optional + :param timestep: Current simulation timestep. Not used by ProxyAgents, bur required for the interface. + :type timestep: int :return: Action to be taken in CAOS format. :rtype: Tuple[str, Dict] """ diff --git a/src/primaite/game/agent/observations.py b/src/primaite/game/agent/observations.py deleted file mode 100644 index dfee2543..00000000 --- a/src/primaite/game/agent/observations.py +++ /dev/null @@ -1,1010 +0,0 @@ -"""Manages the observation space for the agent.""" -from abc import ABC, abstractmethod -from ipaddress import IPv4Address -from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING - -from gymnasium import spaces -from gymnasium.core import ObsType - -from primaite import getLogger -from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE - -_LOGGER = getLogger(__name__) - -if TYPE_CHECKING: - from primaite.game.game import PrimaiteGame - - -class AbstractObservation(ABC): - """Abstract class for an observation space component.""" - - @abstractmethod - def observe(self, state: Dict) -> Any: - """ - Return an observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Any - """ - pass - - @property - @abstractmethod - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space.""" - pass - - @classmethod - @abstractmethod - def from_config(cls, config: Dict, game: "PrimaiteGame"): - """Create this observation space component form a serialised format. - - The `game` parameter is for a the PrimaiteGame object that spawns this component. - """ - pass - - -class FileObservation(AbstractObservation): - """Observation of a file on a node in the network.""" - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """ - Initialise file observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',,'files',] - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.default_observation: spaces.Space = {"health_status": 0} - "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - file_state = access_from_nested_dict(state, self.where) - if file_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return {"health_status": file_state["visible_status"]} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"health_status": spaces.Discrete(6)}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": - """Create file observation from a config. - - :param config: Dictionary containing the configuration for this file observation. - :type config: Dict - :param game: _description_ - :type game: PrimaiteGame - :param parent_where: _description_, defaults to None - :type parent_where: _type_, optional - :return: _description_ - :rtype: _type_ - """ - return cls(where=parent_where + ["files", config["file_name"]]) - - -class ServiceObservation(AbstractObservation): - """Observation of a service in the network.""" - - default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} - "Default observation is what should be returned when the service doesn't exist." - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise service observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'services', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - service_state = access_from_nested_dict(state, self.where) - if service_state is NOT_PRESENT_IN_STATE: - return self.default_observation - return { - "operating_status": service_state["operating_state"], - "health_status": service_state["health_state_visible"], - } - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(6)}) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None - ) -> "ServiceObservation": - """Create service observation from a config. - - :param config: Dictionary containing the configuration for this service observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. - :type parent_where: Optional[List[str]], optional - :return: Constructed service observation - :rtype: ServiceObservation - """ - return cls(where=parent_where + ["services", config["service_name"]]) - - -class LinkObservation(AbstractObservation): - """Observation of a link in the network.""" - - default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} - "Default observation is what should be returned when the link doesn't exist." - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise link observation. - - :param where: Store information about where in the simulation state dictionary to find the relevant information. - Optional. If None, this corresponds that the file does not exist and the observation will be populated with - zeroes. - - A typical location for a service looks like this: - `['network','nodes',,'servics', ]` - :type where: Optional[List[str]] - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - link_state = access_from_nested_dict(state, self.where) - if link_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - bandwidth = link_state["bandwidth"] - load = link_state["current_load"] - if load == 0: - utilisation_category = 0 - else: - utilisation_fraction = load / bandwidth - # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% - utilisation_category = int(utilisation_fraction * 9) + 1 - - # TODO: once the links support separte load per protocol, this needs amendment to reflect that. - return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": - """Create link observation from a config. - - :param config: Dictionary containing the configuration for this link observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed link observation - :rtype: LinkObservation - """ - return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) - - -class FolderObservation(AbstractObservation): - """Folder observation, including files inside of the folder.""" - - def __init__( - self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 - ) -> None: - """Initialise folder Observation, including files inside of the folder. - - :param where: Where in the simulation state dictionary to find the relevant information for this folder. - A typical location for a file looks like this: - ['network','nodes',,'file_system', 'folders',] - :type where: Optional[List[str]] - :param max_files: As size of the space must remain static, define max files that can be in this folder - , defaults to 5 - :type max_files: int, optional - :param file_positions: Defines the positioning within the observation space of particular files. This ensures - that even if new files are created, the existing files will always occupy the same space in the observation - space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the - observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same - name, it will take the position defined in this dict. Defaults to {} - :type file_positions: Dict[int, str], optional - """ - super().__init__() - - self.where: Optional[Tuple[str]] = where - - self.files: List[FileObservation] = files - while len(self.files) < num_files_per_folder: - self.files.append(FileObservation()) - while len(self.files) > num_files_per_folder: - truncated_file = self.files.pop() - msg = f"Too many files in folder observation. Truncating file {truncated_file}" - _LOGGER.warning(msg) - - self.default_observation = { - "health_status": 0, - "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - folder_state = access_from_nested_dict(state, self.where) - if folder_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - health_status = folder_state["health_status"] - - obs = {} - - obs["health_status"] = health_status - obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "health_status": spaces.Discrete(6), - "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), - } - ) - - @classmethod - def from_config( - cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 - ) -> "FolderObservation": - """Create folder observation from a config. Also creates child file observations. - - :param config: Dictionary containing the configuration for this folder observation. Includes the name of the - folder and the files inside of it. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this folder's - parent node. A typical location for a node ``where`` can be: - ['network','nodes',,'file_system'] - :type parent_where: Optional[List[str]] - :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed folder observation - :rtype: FolderObservation - """ - where = parent_where + ["folders", config["folder_name"]] - - file_configs = config["files"] - files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] - - return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) - - -class NicObservation(AbstractObservation): - """Observation of a Network Interface Card (NIC) in the network.""" - - default_observation: spaces.Space = {"nic_status": 0} - - def __init__(self, where: Optional[Tuple[str]] = None) -> None: - """Initialise NIC observation. - - :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical - example may look like this: - ['network','nodes',,'NICs',] - If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. - :type where: Optional[Tuple[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - nic_state = access_from_nested_dict(state, self.where) - if nic_state is NOT_PRESENT_IN_STATE: - return self.default_observation - else: - return {"nic_status": 1 if nic_state["enabled"] else 2} - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict({"nic_status": spaces.Discrete(3)}) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": - """Create NIC observation from a config. - - :param config: Dictionary containing the configuration for this NIC observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent - node. A typical location for a node ``where`` can be: ['network','nodes',] - :type parent_where: Optional[List[str]] - :return: Constructed NIC observation - :rtype: NicObservation - """ - return cls(where=parent_where + ["NICs", config["nic_num"]]) - - -class NodeObservation(AbstractObservation): - """Observation of a node in the network. Includes services, folders and NICs.""" - - def __init__( - self, - where: Optional[Tuple[str]] = None, - services: List[ServiceObservation] = [], - folders: List[FolderObservation] = [], - network_interfaces: List[NicObservation] = [], - logon_status: bool = False, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> None: - """ - Configurable observation for a node in the simulation. - - :param where: Where in the simulation state dictionary for find relevant information for this observation. - A typical location for a node looks like this: - ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] - :type where: List[str], optional - :param services: Mapping between position in observation space and service name, defaults to {} - :type services: Dict[int,str], optional - :param max_services: Max number of services that can be presented in observation space for this node - , defaults to 2 - :type max_services: int, optional - :param folders: Mapping between position in observation space and folder name, defaults to {} - :type folders: Dict[int,str], optional - :param max_folders: Max number of folders in this node's obs space, defaults to 2 - :type max_folders: int, optional - :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} - :type network_interfaces: Dict[int,str], optional - :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 - :type max_nics: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.services: List[ServiceObservation] = services - while len(self.services) < num_services_per_node: - # add empty service observation without `where` parameter so it always returns default (blank) observation - self.services.append(ServiceObservation()) - while len(self.services) > num_services_per_node: - truncated_service = self.services.pop() - msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" - _LOGGER.warning(msg) - # truncate service list - - self.folders: List[FolderObservation] = folders - # add empty folder observation without `where` parameter that will always return default (blank) observations - while len(self.folders) < num_folders_per_node: - self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) - while len(self.folders) > num_folders_per_node: - truncated_folder = self.folders.pop() - msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" - _LOGGER.warning(msg) - - self.network_interfaces: List[NicObservation] = network_interfaces - while len(self.network_interfaces) < num_nics_per_node: - self.network_interfaces.append(NicObservation()) - while len(self.network_interfaces) > num_nics_per_node: - truncated_nic = self.network_interfaces.pop() - msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" - _LOGGER.warning(msg) - - self.logon_status: bool = logon_status - - self.default_observation: Dict = { - "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, - "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, - "NETWORK_INTERFACES": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, - "operating_status": 0, - } - if self.logon_status: - self.default_observation["logon_status"] = 0 - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - node_state = access_from_nested_dict(state, self.where) - if node_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - obs = {} - obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} - obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} - obs["operating_status"] = node_state["operating_state"] - obs["NETWORK_INTERFACES"] = { - i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) - } - - if self.logon_status: - obs["logon_status"] = 0 - - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - space_shape = { - "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), - "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), - "operating_status": spaces.Discrete(5), - "NETWORK_INTERFACES": spaces.Dict( - {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} - ), - } - if self.logon_status: - space_shape["logon_status"] = spaces.Discrete(3) - - return spaces.Dict(space_shape) - - @classmethod - def from_config( - cls, - config: Dict, - game: "PrimaiteGame", - parent_where: Optional[List[str]] = None, - num_services_per_node: int = 2, - num_folders_per_node: int = 2, - num_files_per_folder: int = 2, - num_nics_per_node: int = 2, - ) -> "NodeObservation": - """Create node observation from a config. Also creates child service, folder and NIC observations. - - :param config: Dictionary containing the configuration for this node observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :param parent_where: Where in the simulation state dictionary to find the information about this node's parent - network. A typical location for it would be: ['network',] - :type parent_where: Optional[List[str]] - :param num_services_per_node: How many spaces for services are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_services_per_node: int, optional - :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static - observation size) , defaults to 2 - :type num_folders_per_node: int, optional - :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static - observation size) , defaults to 2 - :type num_files_per_folder: int, optional - :return: Constructed node observation - :rtype: NodeObservation - """ - node_hostname = config["node_hostname"] - if parent_where is None: - where = ["network", "nodes", node_hostname] - else: - where = parent_where + ["nodes", node_hostname] - - svc_configs = config.get("services", {}) - services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] - folder_configs = config.get("folders", {}) - folders = [ - FolderObservation.from_config( - config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder - ) - for c in folder_configs - ] - # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. - nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] - network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] - logon_status = config.get("logon_status", False) - return cls( - where=where, - services=services, - folders=folders, - network_interfaces=network_interfaces, - logon_status=logon_status, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - - -class AclObservation(AbstractObservation): - """Observation of an Access Control List (ACL) in the network.""" - - # TODO: should where be optional, and we can use where=None to pad the observation space? - # definitely the current approach does not support tracking files that aren't specified by name, for example - # if a file is created at runtime, we have currently got no way of telling the observation space to track it. - # this needs adding, but not for the MVP. - def __init__( - self, - node_ip_to_id: Dict[str, int], - ports: List[int], - protocols: List[str], - where: Optional[Tuple[str]] = None, - num_rules: int = 10, - ) -> None: - """Initialise ACL observation. - - :param node_ip_to_id: Mapping between IP address and ID. - :type node_ip_to_id: Dict[str, int] - :param ports: List of ports which are part of the game that define the ordering when converting to an ID - :type ports: List[int] - :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID - :type protocols: list[str] - :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical - example may look like this: - ['network','nodes',,'acl','acl'] - :type where: Optional[Tuple[str]], optional - :param num_rules: , defaults to 10 - :type num_rules: int, optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - self.num_rules: int = num_rules - self.node_to_id: Dict[str, int] = node_ip_to_id - "List of node IP addresses, order in this list determines how they are converted to an ID" - self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} - "List of ports which are part of the game that define the ordering when converting to an ID" - self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} - "List of protocols which are part of the game, defines ordering when converting to an ID" - self.default_observation: Dict = { - i - + 1: { - "position": i, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - for i in range(self.num_rules) - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - acl_state: Dict = access_from_nested_dict(state, self.where) - if acl_state is NOT_PRESENT_IN_STATE: - return self.default_observation - - # TODO: what if the ACL has more rules than num of max rules for obs space - obs = {} - acl_items = dict(acl_state.items()) - i = 1 # don't show rule 0 for compatibility reasons. - while i < self.num_rules + 1: - rule_state = acl_items[i] - if rule_state is None: - obs[i] = { - "position": i - 1, - "permission": 0, - "source_node_id": 0, - "source_port": 0, - "dest_node_id": 0, - "dest_port": 0, - "protocol": 0, - } - else: - src_ip = rule_state["src_ip_address"] - src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] - dst_ip = rule_state["dst_ip_address"] - dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] - src_port = rule_state["src_port"] - src_port_id = 1 if src_port is None else self.port_to_id[src_port] - dst_port = rule_state["dst_port"] - dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] - protocol = rule_state["protocol"] - protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] - obs[i] = { - "position": i - 1, - "permission": rule_state["action"], - "source_node_id": src_node_id, - "source_port": src_port_id, - "dest_node_id": dst_node_ip, - "dest_port": dst_port_id, - "protocol": protocol_id, - } - i += 1 - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape. - - :return: Gymnasium space - :rtype: spaces.Space - """ - return spaces.Dict( - { - i - + 1: spaces.Dict( - { - "position": spaces.Discrete(self.num_rules), - "permission": spaces.Discrete(3), - # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) - "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "source_port": spaces.Discrete(len(self.port_to_id) + 2), - "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), - "dest_port": spaces.Discrete(len(self.port_to_id) + 2), - "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), - } - ) - for i in range(self.num_rules) - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": - """Generate ACL observation from a config. - - :param config: Dictionary containing the configuration for this ACL observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Observation object - :rtype: AclObservation - """ - max_acl_rules = config["options"]["max_acl_rules"] - node_ip_to_idx = {} - for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): - node_ref = ip_map_config["node_hostname"] - nic_num = ip_map_config["nic_num"] - node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] - nic_obj = node_obj.network_interface[nic_num] - node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 - - router_hostname = config["router_hostname"] - return cls( - node_ip_to_id=node_ip_to_idx, - ports=game.options.ports, - protocols=game.options.protocols, - where=["network", "nodes", router_hostname, "acl", "acl"], - num_rules=max_acl_rules, - ) - - -class NullObservation(AbstractObservation): - """Null observation, returns a single 0 value for the observation space.""" - - def __init__(self, where: Optional[List[str]] = None): - """Initialise null observation.""" - self.default_observation: Dict = {} - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - return 0 - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Discrete(1) - - @classmethod - def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": - """ - Create null observation from a config. - - The parameters are ignored, they are here to match the signature of the other observation classes. - """ - return cls() - - -class ICSObservation(NullObservation): - """ICS observation placeholder, currently not implemented so always returns a single 0.""" - - pass - - -class UC2BlueObservation(AbstractObservation): - """Container for all observations used by the blue agent in UC2. - - TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler - for the purpose of compiling several observation components. - """ - - def __init__( - self, - nodes: List[NodeObservation], - links: List[LinkObservation], - acl: AclObservation, - ics: ICSObservation, - where: Optional[List[str]] = None, - ) -> None: - """Initialise UC2 blue observation. - - :param nodes: List of node observations - :type nodes: List[NodeObservation] - :param links: List of link observations - :type links: List[LinkObservation] - :param acl: The Access Control List observation - :type acl: AclObservation - :param ics: The ICS observation - :type ics: ICSObservation - :param where: Where in the simulation state dict to find information. Not used in this particular observation - because it only compiles other observations and doesn't contribute any new information, defaults to None - :type where: Optional[List[str]], optional - """ - super().__init__() - self.where: Optional[Tuple[str]] = where - - self.nodes: List[NodeObservation] = nodes - self.links: List[LinkObservation] = links - self.acl: AclObservation = acl - self.ics: ICSObservation = ics - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, - "ACL": self.acl.default_observation, - "ICS": self.ics.default_observation, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - :return: Observation - :rtype: Dict - """ - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} - obs["ACL"] = self.acl.observe(state) - obs["ICS"] = self.ics.observe(state) - - return obs - - @property - def space(self) -> spaces.Space: - """ - Gymnasium space object describing the observation space shape. - - :return: Space - :rtype: spaces.Space - """ - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), - "ACL": self.acl.space, - "ICS": self.ics.space, - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": - """Create UC2 blue observation from a config. - - :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, - links, ACL and ICS observations. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - :return: Constructed UC2 blue observation - :rtype: UC2BlueObservation - """ - node_configs = config["nodes"] - - num_services_per_node = config["num_services_per_node"] - num_folders_per_node = config["num_folders_per_node"] - num_files_per_folder = config["num_files_per_folder"] - num_nics_per_node = config["num_nics_per_node"] - nodes = [ - NodeObservation.from_config( - config=n, - game=game, - num_services_per_node=num_services_per_node, - num_folders_per_node=num_folders_per_node, - num_files_per_folder=num_files_per_folder, - num_nics_per_node=num_nics_per_node, - ) - for n in node_configs - ] - - link_configs = config["links"] - links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] - - acl_config = config["acl"] - acl = AclObservation.from_config(config=acl_config, game=game) - - ics_config = config["ics"] - ics = ICSObservation.from_config(config=ics_config, game=game) - new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) - return new - - -class UC2RedObservation(AbstractObservation): - """Container for all observations used by the red agent in UC2.""" - - def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: - super().__init__() - self.where: Optional[List[str]] = where - self.nodes: List[NodeObservation] = nodes - - self.default_observation: Dict = { - "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, - } - - def observe(self, state: Dict) -> Dict: - """Generate observation based on the current state of the simulation.""" - if self.where is None: - return self.default_observation - - obs = {} - obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} - return obs - - @property - def space(self) -> spaces.Space: - """Gymnasium space object describing the observation space shape.""" - return spaces.Dict( - { - "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), - } - ) - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": - """ - Create UC2 red observation from a config. - - :param config: Dictionary containing the configuration for this UC2 red observation. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - node_configs = config["nodes"] - nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] - return cls(nodes=nodes, where=["network"]) - - -class UC2GreenObservation(NullObservation): - """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" - - pass - - -class ObservationManager: - """ - Manage the observations of an Agent. - - The observation space has the purpose of: - 1. Reading the outputted state from the PrimAITE Simulation. - 2. Selecting parts of the simulation state that are requested by the simulation config - 3. Formatting this information so an agent can use it to make decisions. - """ - - # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed - # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next - # refactor. - - def __init__(self, observation: AbstractObservation) -> None: - """Initialise observation space. - - :param observation: Observation object - :type observation: AbstractObservation - """ - self.obs: AbstractObservation = observation - self.current_observation: ObsType - - def update(self, state: Dict) -> Dict: - """ - Generate observation based on the current state of the simulation. - - :param state: Simulation state dictionary - :type state: Dict - """ - self.current_observation = self.obs.observe(state) - return self.current_observation - - @property - def space(self) -> None: - """Gymnasium space object describing the observation space shape.""" - return self.obs.space - - @classmethod - def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": - """Create observation space from a config. - - :param config: Dictionary containing the configuration for this observation space. - It should contain the key 'type' which selects which observation class to use (from a choice of: - UC2BlueObservation, UC2RedObservation, UC2GreenObservation) - The other key is 'options' which are passed to the constructor of the selected observation class. - :type config: Dict - :param game: Reference to the PrimaiteGame object that spawned this observation. - :type game: PrimaiteGame - """ - if config["type"] == "UC2BlueObservation": - return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2RedObservation": - return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) - elif config["type"] == "UC2GreenObservation": - return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) - else: - raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/observations/__init__.py b/src/primaite/game/agent/observations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/agent/observations/agent_observations.py b/src/primaite/game/agent/observations/agent_observations.py new file mode 100644 index 00000000..70a83881 --- /dev/null +++ b/src/primaite/game/agent/observations/agent_observations.py @@ -0,0 +1,188 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.game.agent.observations.observations import ( + AbstractObservation, + AclObservation, + ICSObservation, + LinkObservation, + NullObservation, +) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class UC2BlueObservation(AbstractObservation): + """Container for all observations used by the blue agent in UC2. + + TODO: there's no real need for a UC2 blue container class, we should be able to simply use the observation handler + for the purpose of compiling several observation components. + """ + + def __init__( + self, + nodes: List[NodeObservation], + links: List[LinkObservation], + acl: AclObservation, + ics: ICSObservation, + where: Optional[List[str]] = None, + ) -> None: + """Initialise UC2 blue observation. + + :param nodes: List of node observations + :type nodes: List[NodeObservation] + :param links: List of link observations + :type links: List[LinkObservation] + :param acl: The Access Control List observation + :type acl: AclObservation + :param ics: The ICS observation + :type ics: ICSObservation + :param where: Where in the simulation state dict to find information. Not used in this particular observation + because it only compiles other observations and doesn't contribute any new information, defaults to None + :type where: Optional[List[str]], optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + self.nodes: List[NodeObservation] = nodes + self.links: List[LinkObservation] = links + self.acl: AclObservation = acl + self.ics: ICSObservation = ics + + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, + "LINKS": {i + 1: l.default_observation for i, l in enumerate(self.links)}, + "ACL": self.acl.default_observation, + "ICS": self.ics.default_observation, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + obs = {} + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + obs["LINKS"] = {i + 1: link.observe(state) for i, link in enumerate(self.links)} + obs["ACL"] = self.acl.observe(state) + obs["ICS"] = self.ics.observe(state) + + return obs + + @property + def space(self) -> spaces.Space: + """ + Gymnasium space object describing the observation space shape. + + :return: Space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), + "LINKS": spaces.Dict({i + 1: link.space for i, link in enumerate(self.links)}), + "ACL": self.acl.space, + "ICS": self.ics.space, + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2BlueObservation": + """Create UC2 blue observation from a config. + + :param config: Dictionary containing the configuration for this UC2 blue observation. This includes the nodes, + links, ACL and ICS observations. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Constructed UC2 blue observation + :rtype: UC2BlueObservation + """ + node_configs = config["nodes"] + + num_services_per_node = config["num_services_per_node"] + num_folders_per_node = config["num_folders_per_node"] + num_files_per_folder = config["num_files_per_folder"] + num_nics_per_node = config["num_nics_per_node"] + nodes = [ + NodeObservation.from_config( + config=n, + game=game, + num_services_per_node=num_services_per_node, + num_folders_per_node=num_folders_per_node, + num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, + ) + for n in node_configs + ] + + link_configs = config["links"] + links = [LinkObservation.from_config(config=link, game=game) for link in link_configs] + + acl_config = config["acl"] + acl = AclObservation.from_config(config=acl_config, game=game) + + ics_config = config["ics"] + ics = ICSObservation.from_config(config=ics_config, game=game) + new = cls(nodes=nodes, links=links, acl=acl, ics=ics, where=["network"]) + return new + + +class UC2RedObservation(AbstractObservation): + """Container for all observations used by the red agent in UC2.""" + + def __init__(self, nodes: List[NodeObservation], where: Optional[List[str]] = None) -> None: + super().__init__() + self.where: Optional[List[str]] = where + self.nodes: List[NodeObservation] = nodes + + self.default_observation: Dict = { + "NODES": {i + 1: n.default_observation for i, n in enumerate(self.nodes)}, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation.""" + if self.where is None: + return self.default_observation + + obs = {} + obs["NODES"] = {i + 1: node.observe(state) for i, node in enumerate(self.nodes)} + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "NODES": spaces.Dict({i + 1: node.space for i, node in enumerate(self.nodes)}), + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "UC2RedObservation": + """ + Create UC2 red observation from a config. + + :param config: Dictionary containing the configuration for this UC2 red observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + """ + node_configs = config["nodes"] + nodes = [NodeObservation.from_config(config=cfg, game=game) for cfg in node_configs] + return cls(nodes=nodes, where=["network"]) + + +class UC2GreenObservation(NullObservation): + """Green agent observation. As the green agent's actions don't depend on the observation, this is empty.""" + + pass diff --git a/src/primaite/game/agent/observations/file_system_observations.py b/src/primaite/game/agent/observations/file_system_observations.py new file mode 100644 index 00000000..277bc51f --- /dev/null +++ b/src/primaite/game/agent/observations/file_system_observations.py @@ -0,0 +1,177 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class FileObservation(AbstractObservation): + """Observation of a file on a node in the network.""" + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """ + Initialise file observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',,'files',] + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + self.default_observation: spaces.Space = {"health_status": 0} + "Default observation is what should be returned when the file doesn't exist, e.g. after it has been deleted." + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + file_state = access_from_nested_dict(state, self.where) + if file_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return {"health_status": file_state["visible_status"]} + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict({"health_status": spaces.Discrete(6)}) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: List[str] = None) -> "FileObservation": + """Create file observation from a config. + + :param config: Dictionary containing the configuration for this file observation. + :type config: Dict + :param game: _description_ + :type game: PrimaiteGame + :param parent_where: _description_, defaults to None + :type parent_where: _type_, optional + :return: _description_ + :rtype: _type_ + """ + return cls(where=parent_where + ["files", config["file_name"]]) + + +class FolderObservation(AbstractObservation): + """Folder observation, including files inside of the folder.""" + + def __init__( + self, where: Optional[Tuple[str]] = None, files: List[FileObservation] = [], num_files_per_folder: int = 2 + ) -> None: + """Initialise folder Observation, including files inside the folder. + + :param where: Where in the simulation state dictionary to find the relevant information for this folder. + A typical location for a file looks like this: + ['network','nodes',,'file_system', 'folders',] + :type where: Optional[List[str]] + :param max_files: As size of the space must remain static, define max files that can be in this folder + , defaults to 5 + :type max_files: int, optional + :param file_positions: Defines the positioning within the observation space of particular files. This ensures + that even if new files are created, the existing files will always occupy the same space in the observation + space. The keys must be between 1 and max_files. Providing file_positions will reserve a spot in the + observation space for a file with that name, even if it's temporarily deleted, if it reappears with the same + name, it will take the position defined in this dict. Defaults to {} + :type file_positions: Dict[int, str], optional + """ + super().__init__() + + self.where: Optional[Tuple[str]] = where + + self.files: List[FileObservation] = files + while len(self.files) < num_files_per_folder: + self.files.append(FileObservation()) + while len(self.files) > num_files_per_folder: + truncated_file = self.files.pop() + msg = f"Too many files in folder observation. Truncating file {truncated_file}" + _LOGGER.warning(msg) + + self.default_observation = { + "health_status": 0, + "FILES": {i + 1: f.default_observation for i, f in enumerate(self.files)}, + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + folder_state = access_from_nested_dict(state, self.where) + if folder_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + health_status = folder_state["health_status"] + + obs = {} + + obs["health_status"] = health_status + obs["FILES"] = {i + 1: file.observe(state) for i, file in enumerate(self.files)} + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + "health_status": spaces.Discrete(6), + "FILES": spaces.Dict({i + 1: f.space for i, f in enumerate(self.files)}), + } + ) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]], num_files_per_folder: int = 2 + ) -> "FolderObservation": + """Create folder observation from a config. Also creates child file observations. + + :param config: Dictionary containing the configuration for this folder observation. Includes the name of the + folder and the files inside of it. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this folder's + parent node. A typical location for a node ``where`` can be: + ['network','nodes',,'file_system'] + :type parent_where: Optional[List[str]] + :param num_files_per_folder: How many spaces for files are in this folder observation (to preserve static + observation size) , defaults to 2 + :type num_files_per_folder: int, optional + :return: Constructed folder observation + :rtype: FolderObservation + """ + where = parent_where + ["folders", config["folder_name"]] + + file_configs = config["files"] + files = [FileObservation.from_config(config=f, game=game, parent_where=where) for f in file_configs] + + return cls(where=where, files=files, num_files_per_folder=num_files_per_folder) diff --git a/src/primaite/game/agent/observations/nic_observations.py b/src/primaite/game/agent/observations/nic_observations.py new file mode 100644 index 00000000..de83e03a --- /dev/null +++ b/src/primaite/game/agent/observations/nic_observations.py @@ -0,0 +1,188 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +from primaite.simulator.network.nmne import CAPTURE_NMNE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NicObservation(AbstractObservation): + """Observation of a Network Interface Card (NIC) in the network.""" + + low_nmne_threshold: int = 0 + """The minimum number of malicious network events to be considered low.""" + med_nmne_threshold: int = 5 + """The minimum number of malicious network events to be considered medium.""" + high_nmne_threshold: int = 10 + """The minimum number of malicious network events to be considered high.""" + + global CAPTURE_NMNE + + @property + def default_observation(self) -> Dict: + """The default NIC observation dict.""" + data = {"nic_status": 0} + if CAPTURE_NMNE: + data.update({"NMNE": {"inbound": 0, "outbound": 0}}) + + return data + + def __init__( + self, + where: Optional[Tuple[str]] = None, + low_nmne_threshold: Optional[int] = 0, + med_nmne_threshold: Optional[int] = 5, + high_nmne_threshold: Optional[int] = 10, + ) -> None: + """Initialise NIC observation. + + :param where: Where in the simulation state dictionary to find the relevant information for this NIC. A typical + example may look like this: + ['network','nodes',,'NICs',] + If None, this denotes that the NIC does not exist and the observation will be populated with zeroes. + :type where: Optional[Tuple[str]], optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + global CAPTURE_NMNE + if CAPTURE_NMNE: + self.nmne_inbound_last_step: int = 0 + """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets + us find the difference.""" + self.nmne_outbound_last_step: int = 0 + """NMNEs persist for the whole episode, but we want to count per step. Keeping track of last step count lets + us find the difference.""" + + if low_nmne_threshold or med_nmne_threshold or high_nmne_threshold: + self._validate_nmne_categories( + low_nmne_threshold=low_nmne_threshold, + med_nmne_threshold=med_nmne_threshold, + high_nmne_threshold=high_nmne_threshold, + ) + + def _validate_nmne_categories( + self, low_nmne_threshold: int = 0, med_nmne_threshold: int = 5, high_nmne_threshold: int = 10 + ): + """ + Validates the nmne threshold config. + + If the configuration is valid, the thresholds will be set, otherwise, an exception is raised. + + :param: low_nmne_threshold: The minimum number of malicious network events to be considered low + :param: med_nmne_threshold: The minimum number of malicious network events to be considered medium + :param: high_nmne_threshold: The minimum number of malicious network events to be considered high + """ + if high_nmne_threshold <= med_nmne_threshold: + raise Exception( + f"nmne_categories: high nmne count ({high_nmne_threshold}) must be greater " + f"than medium nmne count ({med_nmne_threshold})" + ) + + if med_nmne_threshold <= low_nmne_threshold: + raise Exception( + f"nmne_categories: medium nmne count ({med_nmne_threshold}) must be greater " + f"than low nmne count ({low_nmne_threshold})" + ) + + self.high_nmne_threshold = high_nmne_threshold + self.med_nmne_threshold = med_nmne_threshold + self.low_nmne_threshold = low_nmne_threshold + + def _categorise_mne_count(self, nmne_count: int) -> int: + """ + Categorise the number of Malicious Network Events (NMNEs) into discrete bins. + + This helps in classifying the severity or volume of MNEs into manageable levels for the agent. + + Bins are defined as follows: + - 0: No MNEs detected (0 events). + - 1: Low number of MNEs (default 1-5 events). + - 2: Moderate number of MNEs (default 6-10 events). + - 3: High number of MNEs (default more than 10 events). + + :param nmne_count: Number of MNEs detected. + :return: Bin number corresponding to the number of MNEs. Returns 0, 1, 2, or 3 based on the detected MNE count. + """ + if nmne_count > self.high_nmne_threshold: + return 3 + elif nmne_count > self.med_nmne_threshold: + return 2 + elif nmne_count > self.low_nmne_threshold: + return 1 + return 0 + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + nic_state = access_from_nested_dict(state, self.where) + + if nic_state is NOT_PRESENT_IN_STATE: + return self.default_observation + else: + obs_dict = {"nic_status": 1 if nic_state["enabled"] else 2} + if CAPTURE_NMNE: + obs_dict.update({"NMNE": {}}) + direction_dict = nic_state["nmne"].get("direction", {}) + inbound_keywords = direction_dict.get("inbound", {}).get("keywords", {}) + inbound_count = inbound_keywords.get("*", 0) + outbound_keywords = direction_dict.get("outbound", {}).get("keywords", {}) + outbound_count = outbound_keywords.get("*", 0) + obs_dict["NMNE"]["inbound"] = self._categorise_mne_count(inbound_count - self.nmne_inbound_last_step) + obs_dict["NMNE"]["outbound"] = self._categorise_mne_count(outbound_count - self.nmne_outbound_last_step) + self.nmne_inbound_last_step = inbound_count + self.nmne_outbound_last_step = outbound_count + return obs_dict + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + space = spaces.Dict({"nic_status": spaces.Discrete(3)}) + + if CAPTURE_NMNE: + space["NMNE"] = spaces.Dict({"inbound": spaces.Discrete(4), "outbound": spaces.Discrete(4)}) + + return space + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]]) -> "NicObservation": + """Create NIC observation from a config. + + :param config: Dictionary containing the configuration for this NIC observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this NIC's parent + node. A typical location for a node ``where`` can be: ['network','nodes',] + :type parent_where: Optional[List[str]] + :return: Constructed NIC observation + :rtype: NicObservation + """ + low_nmne_threshold = None + med_nmne_threshold = None + high_nmne_threshold = None + + if game and game.options and game.options.thresholds and game.options.thresholds.get("nmne"): + threshold = game.options.thresholds["nmne"] + + low_nmne_threshold = int(threshold.get("low")) if threshold.get("low") is not None else None + med_nmne_threshold = int(threshold.get("medium")) if threshold.get("medium") is not None else None + high_nmne_threshold = int(threshold.get("high")) if threshold.get("high") is not None else None + + return cls( + where=parent_where + ["NICs", config["nic_num"]], + low_nmne_threshold=low_nmne_threshold, + med_nmne_threshold=med_nmne_threshold, + high_nmne_threshold=high_nmne_threshold, + ) diff --git a/src/primaite/game/agent/observations/node_observations.py b/src/primaite/game/agent/observations/node_observations.py new file mode 100644 index 00000000..94f0974b --- /dev/null +++ b/src/primaite/game/agent/observations/node_observations.py @@ -0,0 +1,200 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.observations.file_system_observations import FolderObservation +from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.observations.software_observation import ServiceObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class NodeObservation(AbstractObservation): + """Observation of a node in the network. Includes services, folders and NICs.""" + + def __init__( + self, + where: Optional[Tuple[str]] = None, + services: List[ServiceObservation] = [], + folders: List[FolderObservation] = [], + network_interfaces: List[NicObservation] = [], + logon_status: bool = False, + num_services_per_node: int = 2, + num_folders_per_node: int = 2, + num_files_per_folder: int = 2, + num_nics_per_node: int = 2, + ) -> None: + """ + Configurable observation for a node in the simulation. + + :param where: Where in the simulation state dictionary for find relevant information for this observation. + A typical location for a node looks like this: + ['network','nodes',]. If empty list, a default null observation will be output, defaults to [] + :type where: List[str], optional + :param services: Mapping between position in observation space and service name, defaults to {} + :type services: Dict[int,str], optional + :param max_services: Max number of services that can be presented in observation space for this node + , defaults to 2 + :type max_services: int, optional + :param folders: Mapping between position in observation space and folder name, defaults to {} + :type folders: Dict[int,str], optional + :param max_folders: Max number of folders in this node's obs space, defaults to 2 + :type max_folders: int, optional + :param network_interfaces: Mapping between position in observation space and NIC idx, defaults to {} + :type network_interfaces: Dict[int,str], optional + :param max_nics: Max number of network interfaces in this node's obs space, defaults to 5 + :type max_nics: int, optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + self.services: List[ServiceObservation] = services + while len(self.services) < num_services_per_node: + # add empty service observation without `where` parameter so it always returns default (blank) observation + self.services.append(ServiceObservation()) + while len(self.services) > num_services_per_node: + truncated_service = self.services.pop() + msg = f"Too many services in Node observation space for node. Truncating service {truncated_service.where}" + _LOGGER.warning(msg) + # truncate service list + + self.folders: List[FolderObservation] = folders + # add empty folder observation without `where` parameter that will always return default (blank) observations + while len(self.folders) < num_folders_per_node: + self.folders.append(FolderObservation(num_files_per_folder=num_files_per_folder)) + while len(self.folders) > num_folders_per_node: + truncated_folder = self.folders.pop() + msg = f"Too many folders in Node observation for node. Truncating service {truncated_folder.where[-1]}" + _LOGGER.warning(msg) + + self.network_interfaces: List[NicObservation] = network_interfaces + while len(self.network_interfaces) < num_nics_per_node: + self.network_interfaces.append(NicObservation()) + while len(self.network_interfaces) > num_nics_per_node: + truncated_nic = self.network_interfaces.pop() + msg = f"Too many NICs in Node observation for node. Truncating service {truncated_nic.where[-1]}" + _LOGGER.warning(msg) + + self.logon_status: bool = logon_status + + self.default_observation: Dict = { + "SERVICES": {i + 1: s.default_observation for i, s in enumerate(self.services)}, + "FOLDERS": {i + 1: f.default_observation for i, f in enumerate(self.folders)}, + "NICS": {i + 1: n.default_observation for i, n in enumerate(self.network_interfaces)}, + "operating_status": 0, + } + if self.logon_status: + self.default_observation["logon_status"] = 0 + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + node_state = access_from_nested_dict(state, self.where) + if node_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + obs = {} + obs["SERVICES"] = {i + 1: service.observe(state) for i, service in enumerate(self.services)} + obs["FOLDERS"] = {i + 1: folder.observe(state) for i, folder in enumerate(self.folders)} + obs["operating_status"] = node_state["operating_state"] + obs["NICS"] = { + i + 1: network_interface.observe(state) for i, network_interface in enumerate(self.network_interfaces) + } + + if self.logon_status: + obs["logon_status"] = 0 + + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + space_shape = { + "SERVICES": spaces.Dict({i + 1: service.space for i, service in enumerate(self.services)}), + "FOLDERS": spaces.Dict({i + 1: folder.space for i, folder in enumerate(self.folders)}), + "operating_status": spaces.Discrete(5), + "NICS": spaces.Dict( + {i + 1: network_interface.space for i, network_interface in enumerate(self.network_interfaces)} + ), + } + if self.logon_status: + space_shape["logon_status"] = spaces.Discrete(3) + + return spaces.Dict(space_shape) + + @classmethod + def from_config( + cls, + config: Dict, + game: "PrimaiteGame", + parent_where: Optional[List[str]] = None, + num_services_per_node: int = 2, + num_folders_per_node: int = 2, + num_files_per_folder: int = 2, + num_nics_per_node: int = 2, + ) -> "NodeObservation": + """Create node observation from a config. Also creates child service, folder and NIC observations. + + :param config: Dictionary containing the configuration for this node observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary to find the information about this node's parent + network. A typical location for it would be: ['network',] + :type parent_where: Optional[List[str]] + :param num_services_per_node: How many spaces for services are in this node observation (to preserve static + observation size) , defaults to 2 + :type num_services_per_node: int, optional + :param num_folders_per_node: How many spaces for folders are in this node observation (to preserve static + observation size) , defaults to 2 + :type num_folders_per_node: int, optional + :param num_files_per_folder: How many spaces for files are in the folder observations (to preserve static + observation size) , defaults to 2 + :type num_files_per_folder: int, optional + :return: Constructed node observation + :rtype: NodeObservation + """ + node_hostname = config["node_hostname"] + if parent_where is None: + where = ["network", "nodes", node_hostname] + else: + where = parent_where + ["nodes", node_hostname] + + svc_configs = config.get("services", {}) + services = [ServiceObservation.from_config(config=c, game=game, parent_where=where) for c in svc_configs] + folder_configs = config.get("folders", {}) + folders = [ + FolderObservation.from_config( + config=c, game=game, parent_where=where + ["file_system"], num_files_per_folder=num_files_per_folder + ) + for c in folder_configs + ] + # create some configs for the NIC observation in the format {"nic_num":1}, {"nic_num":2}, {"nic_num":3}, etc. + nic_configs = [{"nic_num": i for i in range(num_nics_per_node)}] + network_interfaces = [NicObservation.from_config(config=c, game=game, parent_where=where) for c in nic_configs] + logon_status = config.get("logon_status", False) + return cls( + where=where, + services=services, + folders=folders, + network_interfaces=network_interfaces, + logon_status=logon_status, + num_services_per_node=num_services_per_node, + num_folders_per_node=num_folders_per_node, + num_files_per_folder=num_files_per_folder, + num_nics_per_node=num_nics_per_node, + ) diff --git a/src/primaite/game/agent/observations/observation_manager.py b/src/primaite/game/agent/observations/observation_manager.py new file mode 100644 index 00000000..400345fa --- /dev/null +++ b/src/primaite/game/agent/observations/observation_manager.py @@ -0,0 +1,73 @@ +from typing import Dict, TYPE_CHECKING + +from gymnasium.core import ObsType + +from primaite.game.agent.observations.agent_observations import ( + UC2BlueObservation, + UC2GreenObservation, + UC2RedObservation, +) +from primaite.game.agent.observations.observations import AbstractObservation + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class ObservationManager: + """ + Manage the observations of an Agent. + + The observation space has the purpose of: + 1. Reading the outputted state from the PrimAITE Simulation. + 2. Selecting parts of the simulation state that are requested by the simulation config + 3. Formatting this information so an agent can use it to make decisions. + """ + + # TODO: Dear code reader: This class currently doesn't do much except hold an observation object. It will be changed + # to have more of it's own behaviour, and it will replace UC2BlueObservation and UC2RedObservation during the next + # refactor. + + def __init__(self, observation: AbstractObservation) -> None: + """Initialise observation space. + + :param observation: Observation object + :type observation: AbstractObservation + """ + self.obs: AbstractObservation = observation + self.current_observation: ObsType + + def update(self, state: Dict) -> Dict: + """ + Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + """ + self.current_observation = self.obs.observe(state) + return self.current_observation + + @property + def space(self) -> None: + """Gymnasium space object describing the observation space shape.""" + return self.obs.space + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "ObservationManager": + """Create observation space from a config. + + :param config: Dictionary containing the configuration for this observation space. + It should contain the key 'type' which selects which observation class to use (from a choice of: + UC2BlueObservation, UC2RedObservation, UC2GreenObservation) + The other key is 'options' which are passed to the constructor of the selected observation class. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + """ + if config["type"] == "UC2BlueObservation": + return cls(UC2BlueObservation.from_config(config.get("options", {}), game=game)) + elif config["type"] == "UC2RedObservation": + return cls(UC2RedObservation.from_config(config.get("options", {}), game=game)) + elif config["type"] == "UC2GreenObservation": + return cls(UC2GreenObservation.from_config(config.get("options", {}), game=game)) + else: + raise ValueError("Observation space type invalid") diff --git a/src/primaite/game/agent/observations/observations.py b/src/primaite/game/agent/observations/observations.py new file mode 100644 index 00000000..6236b00d --- /dev/null +++ b/src/primaite/game/agent/observations/observations.py @@ -0,0 +1,309 @@ +"""Manages the observation space for the agent.""" +from abc import ABC, abstractmethod +from ipaddress import IPv4Address +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite import getLogger +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +_LOGGER = getLogger(__name__) + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class AbstractObservation(ABC): + """Abstract class for an observation space component.""" + + @abstractmethod + def observe(self, state: Dict) -> Any: + """ + Return an observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Any + """ + pass + + @property + @abstractmethod + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space.""" + pass + + @classmethod + @abstractmethod + def from_config(cls, config: Dict, game: "PrimaiteGame"): + """Create this observation space component form a serialised format. + + The `game` parameter is for a the PrimaiteGame object that spawns this component. + """ + pass + + +class LinkObservation(AbstractObservation): + """Observation of a link in the network.""" + + default_observation: spaces.Space = {"PROTOCOLS": {"ALL": 0}} + "Default observation is what should be returned when the link doesn't exist." + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise link observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'servics', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + link_state = access_from_nested_dict(state, self.where) + if link_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + bandwidth = link_state["bandwidth"] + load = link_state["current_load"] + if load == 0: + utilisation_category = 0 + else: + utilisation_fraction = load / bandwidth + # 0 is UNUSED, 1 is 0%-10%. 2 is 10%-20%. 3 is 20%-30%. And so on... 10 is exactly 100% + utilisation_category = int(utilisation_fraction * 9) + 1 + + # TODO: once the links support separte load per protocol, this needs amendment to reflect that. + return {"PROTOCOLS": {"ALL": min(utilisation_category, 10)}} + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict({"PROTOCOLS": spaces.Dict({"ALL": spaces.Discrete(11)})}) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "LinkObservation": + """Create link observation from a config. + + :param config: Dictionary containing the configuration for this link observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Constructed link observation + :rtype: LinkObservation + """ + return cls(where=["network", "links", game.ref_map_links[config["link_ref"]]]) + + +class AclObservation(AbstractObservation): + """Observation of an Access Control List (ACL) in the network.""" + + # TODO: should where be optional, and we can use where=None to pad the observation space? + # definitely the current approach does not support tracking files that aren't specified by name, for example + # if a file is created at runtime, we have currently got no way of telling the observation space to track it. + # this needs adding, but not for the MVP. + def __init__( + self, + node_ip_to_id: Dict[str, int], + ports: List[int], + protocols: List[str], + where: Optional[Tuple[str]] = None, + num_rules: int = 10, + ) -> None: + """Initialise ACL observation. + + :param node_ip_to_id: Mapping between IP address and ID. + :type node_ip_to_id: Dict[str, int] + :param ports: List of ports which are part of the game that define the ordering when converting to an ID + :type ports: List[int] + :param protocols: List of protocols which are part of the game, defines ordering when converting to an ID + :type protocols: list[str] + :param where: Where in the simulation state dictionary to find the relevant information for this ACL. A typical + example may look like this: + ['network','nodes',,'acl','acl'] + :type where: Optional[Tuple[str]], optional + :param num_rules: , defaults to 10 + :type num_rules: int, optional + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + self.num_rules: int = num_rules + self.node_to_id: Dict[str, int] = node_ip_to_id + "List of node IP addresses, order in this list determines how they are converted to an ID" + self.port_to_id: Dict[int, int] = {port: i + 2 for i, port in enumerate(ports)} + "List of ports which are part of the game that define the ordering when converting to an ID" + self.protocol_to_id: Dict[str, int] = {protocol: i + 2 for i, protocol in enumerate(protocols)} + "List of protocols which are part of the game, defines ordering when converting to an ID" + self.default_observation: Dict = { + i + + 1: { + "position": i, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + for i in range(self.num_rules) + } + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + acl_state: Dict = access_from_nested_dict(state, self.where) + if acl_state is NOT_PRESENT_IN_STATE: + return self.default_observation + + # TODO: what if the ACL has more rules than num of max rules for obs space + obs = {} + acl_items = dict(acl_state.items()) + i = 1 # don't show rule 0 for compatibility reasons. + while i < self.num_rules + 1: + rule_state = acl_items[i] + if rule_state is None: + obs[i] = { + "position": i - 1, + "permission": 0, + "source_node_id": 0, + "source_port": 0, + "dest_node_id": 0, + "dest_port": 0, + "protocol": 0, + } + else: + src_ip = rule_state["src_ip_address"] + src_node_id = 1 if src_ip is None else self.node_to_id[IPv4Address(src_ip)] + dst_ip = rule_state["dst_ip_address"] + dst_node_ip = 1 if dst_ip is None else self.node_to_id[IPv4Address(dst_ip)] + src_port = rule_state["src_port"] + src_port_id = 1 if src_port is None else self.port_to_id[src_port] + dst_port = rule_state["dst_port"] + dst_port_id = 1 if dst_port is None else self.port_to_id[dst_port] + protocol = rule_state["protocol"] + protocol_id = 1 if protocol is None else self.protocol_to_id[protocol] + obs[i] = { + "position": i - 1, + "permission": rule_state["action"], + "source_node_id": src_node_id, + "source_port": src_port_id, + "dest_node_id": dst_node_ip, + "dest_port": dst_port_id, + "protocol": protocol_id, + } + i += 1 + return obs + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape. + + :return: Gymnasium space + :rtype: spaces.Space + """ + return spaces.Dict( + { + i + + 1: spaces.Dict( + { + "position": spaces.Discrete(self.num_rules), + "permission": spaces.Discrete(3), + # adding two to lengths is to account for reserved values 0 (unused) and 1 (any) + "source_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "source_port": spaces.Discrete(len(self.port_to_id) + 2), + "dest_node_id": spaces.Discrete(len(set(self.node_to_id.values())) + 2), + "dest_port": spaces.Discrete(len(self.port_to_id) + 2), + "protocol": spaces.Discrete(len(self.protocol_to_id) + 2), + } + ) + for i in range(self.num_rules) + } + ) + + @classmethod + def from_config(cls, config: Dict, game: "PrimaiteGame") -> "AclObservation": + """Generate ACL observation from a config. + + :param config: Dictionary containing the configuration for this ACL observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :return: Observation object + :rtype: AclObservation + """ + max_acl_rules = config["options"]["max_acl_rules"] + node_ip_to_idx = {} + for ip_idx, ip_map_config in enumerate(config["ip_address_order"]): + node_ref = ip_map_config["node_hostname"] + nic_num = ip_map_config["nic_num"] + node_obj = game.simulation.network.nodes[game.ref_map_nodes[node_ref]] + nic_obj = node_obj.network_interface[nic_num] + node_ip_to_idx[nic_obj.ip_address] = ip_idx + 2 + + router_hostname = config["router_hostname"] + return cls( + node_ip_to_id=node_ip_to_idx, + ports=game.options.ports, + protocols=game.options.protocols, + where=["network", "nodes", router_hostname, "acl", "acl"], + num_rules=max_acl_rules, + ) + + +class NullObservation(AbstractObservation): + """Null observation, returns a single 0 value for the observation space.""" + + def __init__(self, where: Optional[List[str]] = None): + """Initialise null observation.""" + self.default_observation: Dict = {} + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation.""" + return 0 + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Discrete(1) + + @classmethod + def from_config(cls, config: Dict, game: Optional["PrimaiteGame"] = None) -> "NullObservation": + """ + Create null observation from a config. + + The parameters are ignored, they are here to match the signature of the other observation classes. + """ + return cls() + + +class ICSObservation(NullObservation): + """ICS observation placeholder, currently not implemented so always returns a single 0.""" + + pass diff --git a/src/primaite/game/agent/observations/software_observation.py b/src/primaite/game/agent/observations/software_observation.py new file mode 100644 index 00000000..6caf791c --- /dev/null +++ b/src/primaite/game/agent/observations/software_observation.py @@ -0,0 +1,163 @@ +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from gymnasium import spaces + +from primaite.game.agent.observations.observations import AbstractObservation +from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE + +if TYPE_CHECKING: + from primaite.game.game import PrimaiteGame + + +class ServiceObservation(AbstractObservation): + """Observation of a service in the network.""" + + default_observation: spaces.Space = {"operating_status": 0, "health_status": 0} + "Default observation is what should be returned when the service doesn't exist." + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise service observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'services', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + service_state = access_from_nested_dict(state, self.where) + if service_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": service_state["operating_state"], + "health_status": service_state["health_state_visible"], + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict({"operating_status": spaces.Discrete(7), "health_status": spaces.Discrete(5)}) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None + ) -> "ServiceObservation": + """Create service observation from a config. + + :param config: Dictionary containing the configuration for this service observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. + :type parent_where: Optional[List[str]], optional + :return: Constructed service observation + :rtype: ServiceObservation + """ + return cls(where=parent_where + ["services", config["service_name"]]) + + +class ApplicationObservation(AbstractObservation): + """Observation of an application in the network.""" + + default_observation: spaces.Space = {"operating_status": 0, "health_status": 0, "num_executions": 0} + "Default observation is what should be returned when the application doesn't exist." + + def __init__(self, where: Optional[Tuple[str]] = None) -> None: + """Initialise application observation. + + :param where: Store information about where in the simulation state dictionary to find the relevant information. + Optional. If None, this corresponds that the file does not exist and the observation will be populated with + zeroes. + + A typical location for a service looks like this: + `['network','nodes',,'applications', ]` + :type where: Optional[List[str]] + """ + super().__init__() + self.where: Optional[Tuple[str]] = where + + def observe(self, state: Dict) -> Dict: + """Generate observation based on the current state of the simulation. + + :param state: Simulation state dictionary + :type state: Dict + :return: Observation + :rtype: Dict + """ + if self.where is None: + return self.default_observation + + app_state = access_from_nested_dict(state, self.where) + if app_state is NOT_PRESENT_IN_STATE: + return self.default_observation + return { + "operating_status": app_state["operating_state"], + "health_status": app_state["health_state_visible"], + "num_executions": self._categorise_num_executions(app_state["num_executions"]), + } + + @property + def space(self) -> spaces.Space: + """Gymnasium space object describing the observation space shape.""" + return spaces.Dict( + { + "operating_status": spaces.Discrete(7), + "health_status": spaces.Discrete(6), + "num_executions": spaces.Discrete(4), + } + ) + + @classmethod + def from_config( + cls, config: Dict, game: "PrimaiteGame", parent_where: Optional[List[str]] = None + ) -> "ApplicationObservation": + """Create application observation from a config. + + :param config: Dictionary containing the configuration for this service observation. + :type config: Dict + :param game: Reference to the PrimaiteGame object that spawned this observation. + :type game: PrimaiteGame + :param parent_where: Where in the simulation state dictionary this service's parent node is located. Optional. + :type parent_where: Optional[List[str]], optional + :return: Constructed service observation + :rtype: ApplicationObservation + """ + return cls(where=parent_where + ["services", config["application_name"]]) + + @classmethod + def _categorise_num_executions(cls, num_executions: int) -> int: + """ + Categorise the number of executions of an application. + + Helps classify the number of application executions into different categories. + + Current categories: + - 0: Application is never executed + - 1: Application is executed a low number of times (1-5) + - 2: Application is executed often (6-10) + - 3: Application is executed a high number of times (more than 10) + + :param: num_executions: Number of times the application is executed + """ + if num_executions > 10: + return 3 + elif num_executions > 5: + return 2 + elif num_executions > 0: + return 1 + return 0 diff --git a/src/primaite/game/agent/rewards.py b/src/primaite/game/agent/rewards.py index b5d5f998..2201b09e 100644 --- a/src/primaite/game/agent/rewards.py +++ b/src/primaite/game/agent/rewards.py @@ -13,7 +13,7 @@ the structure: - type: DATABASE_FILE_INTEGRITY weight: 0.5 options: - node_ref: database_server + node_name: database_server folder_name: database file_name: database.db @@ -21,16 +21,21 @@ the structure: - type: WEB_SERVER_404_PENALTY weight: 0.5 options: - node_ref: web_server + node_name: web_server service_ref: web_server_database_client ``` """ from abc import abstractmethod -from typing import Dict, List, Tuple, Type +from typing import Callable, Dict, List, Optional, Tuple, Type, TYPE_CHECKING + +from typing_extensions import Never from primaite import getLogger from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE +if TYPE_CHECKING: + from primaite.game.agent.interface import AgentActionHistoryItem + _LOGGER = getLogger(__name__) @@ -38,7 +43,7 @@ class AbstractReward: """Base class for reward function components.""" @abstractmethod - def calculate(self, state: Dict) -> float: + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -58,7 +63,7 @@ class AbstractReward: class DummyReward(AbstractReward): """Dummy reward function component which always returns 0.""" - def calculate(self, state: Dict) -> float: + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state.""" return 0.0 @@ -98,7 +103,7 @@ class DatabaseFileIntegrity(AbstractReward): file_name, ] - def calculate(self, state: Dict) -> float: + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -106,7 +111,7 @@ class DatabaseFileIntegrity(AbstractReward): """ database_file_state = access_from_nested_dict(state, self.location_in_state) if database_file_state is NOT_PRESENT_IN_STATE: - _LOGGER.info( + _LOGGER.debug( f"Could not calculate {self.__class__} reward because " "simulation state did not contain enough information." ) @@ -153,7 +158,7 @@ class WebServer404Penalty(AbstractReward): """ self.location_in_state = ["network", "nodes", node_hostname, "services", service_name] - def calculate(self, state: Dict) -> float: + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the reward for the current state. :param state: The current state of the simulation. @@ -184,7 +189,7 @@ class WebServer404Penalty(AbstractReward): service_name = config.get("service_name") if not (node_hostname and service_name): msg = ( - f"{cls.__name__} could not be initialised from config because node_ref and service_ref were not " + f"{cls.__name__} could not be initialised from config because node_name and service_ref were not " "found in reward config." ) _LOGGER.warning(msg) @@ -203,19 +208,30 @@ class WebpageUnavailablePenalty(AbstractReward): :param node_hostname: Hostname of the node which has the web browser. :type node_hostname: str """ - self._node = node_hostname - self.location_in_state = ["network", "nodes", node_hostname, "applications", "WebBrowser"] + self._node: str = node_hostname + self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"] + self._last_request_failed: bool = False - def calculate(self, state: Dict) -> float: + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """ - Calculate the reward based on current simulation state. + Calculate the reward based on current simulation state, and the recent agent action. - :param state: The current state of the simulation. - :type state: Dict + When the green agent requests to execute the browser application, and that request fails, this reward + component will keep track of that information. In that case, it doesn't matter whether the last webpage + had a 200 status code, because there has been an unsuccessful request since. """ + if last_action_response.request == ["network", "node", self._node, "application", "WebBrowser", "execute"]: + self._last_request_failed = last_action_response.response.status != "success" + + # if agent couldn't even get as far as sending the request (because for example the node was off), then + # apply a penalty + if self._last_request_failed: + return -1.0 + + # If the last request did actually go through, then check if the webpage also loaded web_browser_state = access_from_nested_dict(state, self.location_in_state) if web_browser_state is NOT_PRESENT_IN_STATE or "history" not in web_browser_state: - _LOGGER.info( + _LOGGER.debug( "Web browser reward could not be calculated because the web browser history on node", f"{self._node} was not reported in the simulation state. Returning 0.0", ) @@ -242,15 +258,117 @@ class WebpageUnavailablePenalty(AbstractReward): return cls(node_hostname=node_hostname) +class GreenAdminDatabaseUnreachablePenalty(AbstractReward): + """Penalises the agent when the green db clients fail to connect to the database.""" + + def __init__(self, node_hostname: str) -> None: + """ + Initialise the reward component. + + :param node_hostname: Hostname of the node where the database client sits. + :type node_hostname: str + """ + self._node: str = node_hostname + self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"] + self._last_request_failed: bool = False + + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + """ + Calculate the reward based on current simulation state, and the recent agent action. + + When the green agent requests to execute the database client application, and that request fails, this reward + component will keep track of that information. In that case, it doesn't matter whether the last successful + request returned was able to connect to the database server, because there has been an unsuccessful request + since. + """ + if last_action_response.request == ["network", "node", self._node, "application", "DatabaseClient", "execute"]: + self._last_request_failed = last_action_response.response.status != "success" + + # if agent couldn't even get as far as sending the request (because for example the node was off), then + # apply a penalty + if self._last_request_failed: + return -1.0 + + # If the last request was actually sent, then check if the connection was established. + db_state = access_from_nested_dict(state, self.location_in_state) + if db_state is NOT_PRESENT_IN_STATE or "last_connection_successful" not in db_state: + _LOGGER.debug(f"Can't calculate reward for {self.__class__.__name__}") + last_connection_successful = db_state["last_connection_successful"] + if last_connection_successful is False: + return -1.0 + elif last_connection_successful is True: + return 1.0 + return 0.0 + + @classmethod + def from_config(cls, config: Dict) -> AbstractReward: + """ + Build the reward component object from config. + + :param config: Configuration dictionary. + :type config: Dict + """ + node_hostname = config.get("node_hostname") + return cls(node_hostname=node_hostname) + + +class SharedReward(AbstractReward): + """Adds another agent's reward to the overall reward.""" + + def __init__(self, agent_name: Optional[str] = None) -> None: + """ + Initialise the shared reward. + + The agent_name is a placeholder value. It starts off as none, but it must be set before this reward can work + correctly. + + :param agent_name: The name whose reward is an input + :type agent_name: Optional[str] + """ + self.agent_name = agent_name + """Agent whose reward to track.""" + + def default_callback(agent_name: str) -> Never: + """ + Default callback to prevent calling this reward until it's properly initialised. + + SharedReward should not be used until the game layer replaces self.callback with a reference to the + function that retrieves the desired agent's reward. Therefore, we define this default callback that raises + an error. + """ + raise RuntimeError("Attempted to calculate SharedReward but it was not initialised properly.") + + self.callback: Callable[[str], float] = default_callback + """Method that retrieves an agent's current reward given the agent's name.""" + + def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: + """Simply access the other agent's reward and return it.""" + return self.callback(self.agent_name) + + @classmethod + def from_config(cls, config: Dict) -> "SharedReward": + """ + Build the SharedReward object from config. + + :param config: Configuration dictionary + :type config: Dict + """ + agent_name = config.get("agent_name") + return cls(agent_name=agent_name) + + class RewardFunction: """Manages the reward function for the agent.""" - __rew_class_identifiers: Dict[str, Type[AbstractReward]] = { + rew_class_identifiers: Dict[str, Type[AbstractReward]] = { "DUMMY": DummyReward, "DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity, "WEB_SERVER_404_PENALTY": WebServer404Penalty, "WEBPAGE_UNAVAILABLE_PENALTY": WebpageUnavailablePenalty, + "GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY": GreenAdminDatabaseUnreachablePenalty, + "SHARED_REWARD": SharedReward, } + """List of reward class identifiers.""" def __init__(self): """Initialise the reward function object.""" @@ -269,7 +387,7 @@ class RewardFunction: """ self.reward_components.append((component, weight)) - def update(self, state: Dict) -> float: + def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float: """Calculate the overall reward for the current state. :param state: The current state of the simulation. @@ -279,7 +397,7 @@ class RewardFunction: for comp_and_weight in self.reward_components: comp = comp_and_weight[0] weight = comp_and_weight[1] - total += weight * comp.calculate(state=state) + total += weight * comp.calculate(state=state, last_action_response=last_action_response) self.current_reward = total return self.current_reward @@ -297,7 +415,7 @@ class RewardFunction: for rew_component_cfg in config["reward_components"]: rew_type = rew_component_cfg["type"] weight = rew_component_cfg.get("weight", 1.0) - rew_class = cls.__rew_class_identifiers[rew_type] + rew_class = cls.rew_class_identifiers[rew_type] rew_instance = rew_class.from_config(config=rew_component_cfg.get("options", {})) new.register_component(component=rew_instance, weight=weight) return new diff --git a/src/primaite/game/agent/scripted_agents.py b/src/primaite/game/agent/scripted_agents.py deleted file mode 100644 index 3748494b..00000000 --- a/src/primaite/game/agent/scripted_agents.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Agents with predefined behaviours.""" -from primaite.game.agent.interface import AbstractScriptedAgent - - -class GreenWebBrowsingAgent(AbstractScriptedAgent): - """Scripted agent which attempts to send web requests to a target node.""" - - raise NotImplementedError - - -class RedDatabaseCorruptingAgent(AbstractScriptedAgent): - """Scripted agent which attempts to corrupt the database of the target node.""" - - raise NotImplementedError diff --git a/src/primaite/game/agent/scripted_agents/__init__.py b/src/primaite/game/agent/scripted_agents/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py new file mode 100644 index 00000000..d3ec19cb --- /dev/null +++ b/src/primaite/game/agent/scripted_agents/data_manipulation_bot.py @@ -0,0 +1,55 @@ +import random +from typing import Dict, Tuple + +from gymnasium.core import ObsType + +from primaite.game.agent.interface import AbstractScriptedAgent + + +class DataManipulationAgent(AbstractScriptedAgent): + """Agent that uses a DataManipulationBot to perform an SQL injection attack.""" + + next_execution_timestep: int = 0 + starting_node_idx: int = 0 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.setup_agent() + + def _set_next_execution_timestep(self, timestep: int) -> None: + """Set the next execution timestep with a configured random variance. + + :param timestep: The timestep to add variance to. + """ + random_timestep_increment = random.randint( + -self.agent_settings.start_settings.variance, self.agent_settings.start_settings.variance + ) + self.next_execution_timestep = timestep + random_timestep_increment + + def get_action(self, obs: ObsType, timestep: int) -> Tuple[str, Dict]: + """Waits until a specific timestep, then attempts to execute its data manipulation application. + + :param obs: Current observation for this agent, not used in DataManipulationAgent + :type obs: ObsType + :param timestep: The current simulation timestep, used for scheduling actions + :type timestep: int + :return: Action formatted in CAOS format + :rtype: Tuple[str, Dict] + """ + if timestep < self.next_execution_timestep: + return "DONOTHING", {} + + self._set_next_execution_timestep(timestep + self.agent_settings.start_settings.frequency) + + return "NODE_APPLICATION_EXECUTE", {"node_id": self.starting_node_idx, "application_id": 0} + + def setup_agent(self) -> None: + """Set the next execution timestep when the episode resets.""" + self._select_start_node() + self._set_next_execution_timestep(self.agent_settings.start_settings.start_step) + + def _select_start_node(self) -> None: + """Set the starting starting node of the agent to be a random node from this agent's action manager.""" + # we are assuming that every node in the node manager has a data manipulation application at idx 0 + num_nodes = len(self.action_manager.node_names) + self.starting_node_idx = random.randint(0, num_nodes - 1) diff --git a/src/primaite/game/agent/scripted_agents/probabilistic_agent.py b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py new file mode 100644 index 00000000..9cddc978 --- /dev/null +++ b/src/primaite/game/agent/scripted_agents/probabilistic_agent.py @@ -0,0 +1,87 @@ +"""Agents with predefined behaviours.""" +from typing import Dict, Optional, Tuple + +import numpy as np +import pydantic +from gymnasium.core import ObsType + +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.interface import AbstractScriptedAgent +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.rewards import RewardFunction + + +class ProbabilisticAgent(AbstractScriptedAgent): + """Scripted agent which randomly samples its action space with prescribed probabilities for each action.""" + + class Settings(pydantic.BaseModel): + """Config schema for Probabilistic agent settings.""" + + model_config = pydantic.ConfigDict(extra="forbid") + """Strict validation.""" + action_probabilities: Dict[int, float] + """Probability to perform each action in the action map. The sum of probabilities should sum to 1.""" + random_seed: Optional[int] = None + """Random seed. If set, each episode the agent will choose the same random sequence of actions.""" + # TODO: give the option to still set a random seed, but have it vary each episode in a predictable way + # for example if the user sets seed 123, have it be 123 + episode_num, so that each ep it's the next seed. + + @pydantic.field_validator("action_probabilities", mode="after") + @classmethod + def probabilities_sum_to_one(cls, v: Dict[int, float]) -> Dict[int, float]: + """Make sure the probabilities sum to 1.""" + if not abs(sum(v.values()) - 1) < 1e-6: + raise ValueError("Green action probabilities must sum to 1") + return v + + @pydantic.field_validator("action_probabilities", mode="after") + @classmethod + def action_map_covered_correctly(cls, v: Dict[int, float]) -> Dict[int, float]: + """Ensure that the keys of the probability dictionary cover all integers from 0 to N.""" + if not all((i in v) for i in range(len(v))): + raise ValueError( + "Green action probabilities must be defined as a mapping where the keys are consecutive integers " + "from 0 to N." + ) + return v + + def __init__( + self, + agent_name: str, + action_space: Optional[ActionManager], + observation_space: Optional[ObservationManager], + reward_function: Optional[RewardFunction], + settings: Dict = {}, + ) -> None: + # If the action probabilities are not specified, create equal probabilities for all actions + if "action_probabilities" not in settings: + num_actions = len(action_space.action_map) + settings = {"action_probabilities": {i: 1 / num_actions for i in range(num_actions)}} + + # If seed not specified, set it to None so that numpy chooses a random one. + settings.setdefault("random_seed") + + self.settings = ProbabilisticAgent.Settings(**settings) + + self.rng = np.random.default_rng(self.settings.random_seed) + + # convert probabilities from + self.probabilities = np.asarray(list(self.settings.action_probabilities.values())) + + super().__init__(agent_name, action_space, observation_space, reward_function) + + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """ + Sample the action space randomly. + + The probability of each action is given by the corresponding index in ``self.probabilities``. + + :param obs: Current observation for this agent, not used in ProbabilisticAgent + :type obs: ObsType + :param timestep: The current simulation timestep, not used in ProbabilisticAgent + :type timestep: int + :return: Action formatted in CAOS format + :rtype: Tuple[str, Dict] + """ + choice = self.rng.choice(len(self.action_manager.action_map), p=self.probabilities) + return self.action_manager.get_action(choice) diff --git a/src/primaite/game/agent/scripted_agents/random_agent.py b/src/primaite/game/agent/scripted_agents/random_agent.py new file mode 100644 index 00000000..34a4b5ac --- /dev/null +++ b/src/primaite/game/agent/scripted_agents/random_agent.py @@ -0,0 +1,21 @@ +from typing import Dict, Tuple + +from gymnasium.core import ObsType + +from primaite.game.agent.interface import AbstractScriptedAgent + + +class RandomAgent(AbstractScriptedAgent): + """Agent that ignores its observation and acts completely at random.""" + + def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]: + """Sample the action space randomly. + + :param obs: Current observation for this agent, not used in RandomAgent + :type obs: ObsType + :param timestep: The current simulation timestep, not used in RandomAgent + :type timestep: int + :return: Action formatted in CAOS format + :rtype: Tuple[str, Dict] + """ + return self.action_manager.get_action(self.action_manager.space.sample()) diff --git a/src/primaite/game/game.py b/src/primaite/game/game.py index c03bca36..05b76679 100644 --- a/src/primaite/game/game.py +++ b/src/primaite/game/game.py @@ -1,22 +1,25 @@ """PrimAITE game - Encapsulates the simulation and agents.""" from ipaddress import IPv4Address -from typing import Dict, List +from typing import Dict, List, Optional from pydantic import BaseModel, ConfigDict from primaite import getLogger from primaite.game.agent.actions import ActionManager -from primaite.game.agent.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent, RandomAgent -from primaite.game.agent.observations import ObservationManager -from primaite.game.agent.rewards import RewardFunction -from primaite.session.io import SessionIO, SessionIOSettings +from primaite.game.agent.interface import AbstractAgent, AgentSettings, ProxyAgent +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.rewards import RewardFunction, SharedReward +from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent +from primaite.game.science import graph_has_cycle, topological_sort from primaite.simulator.network.hardware.base import NodeOperatingState from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.network.hardware.nodes.host.host_node import NIC from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall from primaite.simulator.network.hardware.nodes.network.router import Router from primaite.simulator.network.hardware.nodes.network.switch import Switch +from primaite.simulator.network.nmne import set_nmne_config from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.database_client import DatabaseClient @@ -40,6 +43,7 @@ APPLICATION_TYPES_MAPPING = { "DataManipulationBot": DataManipulationBot, "DoSBot": DoSBot, } +"""List of available applications that can be installed on nodes in the PrimAITE Simulation.""" SERVICE_TYPES_MAPPING = { "DNSClient": DNSClient, @@ -51,6 +55,7 @@ SERVICE_TYPES_MAPPING = { "NTPClient": NTPClient, "NTPServer": NTPServer, } +"""List of available services that can be installed on nodes in the PrimAITE Simulation.""" class PrimaiteGameOptions(BaseModel): @@ -63,8 +68,13 @@ class PrimaiteGameOptions(BaseModel): model_config = ConfigDict(extra="forbid") max_episode_length: int = 256 + """Maximum number of episodes for the PrimAITE game.""" ports: List[str] + """A whitelist of available ports in the simulation.""" protocols: List[str] + """A whitelist of available protocols in the simulation.""" + thresholds: Optional[Dict] = {} + """A dict containing the thresholds used for determining what is acceptable during observations.""" class PrimaiteGame: @@ -79,18 +89,15 @@ class PrimaiteGame: self.simulation: Simulation = Simulation() """Simulation object with which the agents will interact.""" - self.agents: List[AbstractAgent] = [] - """List of agents.""" + self.agents: Dict[str, AbstractAgent] = {} + """Mapping from agent name to agent object.""" - self.rl_agents: List[ProxyAgent] = [] - """Subset of agent list including only the reinforcement learning agents.""" + self.rl_agents: Dict[str, ProxyAgent] = {} + """Subset of agents which are intended for reinforcement learning.""" self.step_counter: int = 0 """Current timestep within the episode.""" - self.episode_counter: int = 0 - """Current episode number.""" - self.options: PrimaiteGameOptions """Special options that apply for the entire game.""" @@ -109,6 +116,9 @@ class PrimaiteGame: self.save_step_metadata: bool = False """Whether to save the RL agents' action, environment state, and other data at every single step.""" + self._reward_calculation_order: List[str] = [name for name in self.agents] + """Agent order for reward evaluation, as some rewards can be dependent on other agents' rewards.""" + def step(self): """ Perform one step of the simulation/agent loop. @@ -129,40 +139,49 @@ class PrimaiteGame: """ _LOGGER.debug(f"Stepping. Step counter: {self.step_counter}") - # Get the current state of the simulation - sim_state = self.get_sim_state() - - # Update agents' observations and rewards based on the current state - self.update_agents(sim_state) - + if self.step_counter == 0: + state = self.get_sim_state() + for agent in self.agents.values(): + agent.update_observation(state=state) # Apply all actions to simulation as requests - agent_actions = self.apply_agent_actions() # noqa + self.apply_agent_actions() # Advance timestep self.advance_timestep() + # Get the current state of the simulation + sim_state = self.get_sim_state() + + # Update agents' observations and rewards based on the current state, and the response from the last action + self.update_agents(state=sim_state) + def get_sim_state(self) -> Dict: """Get the current state of the simulation.""" return self.simulation.describe_state() def update_agents(self, state: Dict) -> None: """Update agents' observations and rewards based on the current state.""" - for agent in self.agents: - agent.update_observation(state) - agent.update_reward(state) + for agent_name in self._reward_calculation_order: + agent = self.agents[agent_name] + if self.step_counter > 0: # can't get reward before first action + agent.update_reward(state=state) + agent.update_observation(state=state) # order of this doesn't matter so just use reward order agent.reward_function.total_reward += agent.reward_function.current_reward def apply_agent_actions(self) -> None: """Apply all actions to simulation as requests.""" - agent_actions = {} - for agent in self.agents: + for _, agent in self.agents.items(): obs = agent.observation_manager.current_observation - rew = agent.reward_function.current_reward - action_choice, options = agent.get_action(obs, rew) - agent_actions[agent.agent_name] = (action_choice, options) - request = agent.format_request(action_choice, options) - self.simulation.apply_request(request) - return agent_actions + action_choice, parameters = agent.get_action(obs, timestep=self.step_counter) + request = agent.format_request(action_choice, parameters) + response = self.simulation.apply_request(request) + agent.process_action_response( + timestep=self.step_counter, + action=action_choice, + parameters=parameters, + request=request, + response=response, + ) def advance_timestep(self) -> None: """Advance timestep.""" @@ -178,20 +197,14 @@ class PrimaiteGame: return True return False - def reset(self) -> None: - """Reset the game, this will reset the simulation.""" - self.episode_counter += 1 - self.step_counter = 0 - _LOGGER.debug(f"Resetting primaite game, episode = {self.episode_counter}") - self.simulation.reset_component_for_episode(episode=self.episode_counter) - for agent in self.agents: - agent.reward_function.total_reward = 0.0 - agent.reset_agent_for_episode() - def close(self) -> None: """Close the game, this will close the simulation.""" return NotImplemented + def setup_for_episode(self, episode: int) -> None: + """Perform any final configuration of components to make them ready for the game to start.""" + self.simulation.setup_for_episode(episode=episode) + @classmethod def from_config(cls, cfg: Dict) -> "PrimaiteGame": """Create a PrimaiteGame object from a config dictionary. @@ -209,10 +222,6 @@ class PrimaiteGame: :return: A PrimaiteGame object. :rtype: PrimaiteGame """ - io_settings = cfg.get("io_settings", {}) - _ = SessionIO(SessionIOSettings(**io_settings)) - # Instantiating this ensures that the game saves to the correct output dir even without being part of a session - game = cls() game.options = PrimaiteGameOptions(**cfg["game"]) game.save_step_metadata = cfg.get("io_settings", {}).get("save_step_metadata") or False @@ -221,8 +230,12 @@ class PrimaiteGame: sim = game.simulation net = sim.network - nodes_cfg = cfg["simulation"]["network"]["nodes"] - links_cfg = cfg["simulation"]["network"]["links"] + simulation_config = cfg.get("simulation", {}) + network_config = simulation_config.get("network", {}) + + nodes_cfg = network_config.get("nodes", []) + links_cfg = network_config.get("links", []) + for node_cfg in nodes_cfg: node_ref = node_cfg["ref"] n_type = node_cfg["type"] @@ -230,28 +243,36 @@ class PrimaiteGame: new_node = Computer( hostname=node_cfg["hostname"], ip_address=node_cfg["ip_address"], - subnet_mask=node_cfg["subnet_mask"], + subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), default_gateway=node_cfg["default_gateway"], - dns_server=node_cfg["dns_server"], - operating_state=NodeOperatingState.ON, + dns_server=node_cfg.get("dns_server", None), + operating_state=NodeOperatingState.ON + if not (p := node_cfg.get("operating_state")) + else NodeOperatingState[p.upper()], ) elif n_type == "server": new_node = Server( hostname=node_cfg["hostname"], ip_address=node_cfg["ip_address"], - subnet_mask=node_cfg["subnet_mask"], + subnet_mask=IPv4Address(node_cfg.get("subnet_mask", "255.255.255.0")), default_gateway=node_cfg["default_gateway"], - dns_server=node_cfg.get("dns_server"), - operating_state=NodeOperatingState.ON, + dns_server=node_cfg.get("dns_server", None), + operating_state=NodeOperatingState.ON + if not (p := node_cfg.get("operating_state")) + else NodeOperatingState[p.upper()], ) elif n_type == "switch": new_node = Switch( hostname=node_cfg["hostname"], - num_ports=node_cfg.get("num_ports"), - operating_state=NodeOperatingState.ON, + num_ports=int(node_cfg.get("num_ports", "8")), + operating_state=NodeOperatingState.ON + if not (p := node_cfg.get("operating_state")) + else NodeOperatingState[p.upper()], ) elif n_type == "router": new_node = Router.from_config(node_cfg) + elif n_type == "firewall": + new_node = Firewall.from_config(node_cfg) else: _LOGGER.warning(f"invalid node type {n_type} in config") if "services" in node_cfg: @@ -264,8 +285,13 @@ class PrimaiteGame: new_node.software_manager.install(SERVICE_TYPES_MAPPING[service_type]) new_service = new_node.software_manager.software[service_type] game.ref_map_services[service_ref] = new_service.uuid + + # start the service + new_service.start() else: - _LOGGER.warning(f"service type not found {service_type}") + msg = f"Configuration contains an invalid service type: {service_type}" + _LOGGER.error(msg) + raise ValueError(msg) # service-dependent options if service_type == "DNSClient": if "options" in service_cfg: @@ -281,18 +307,16 @@ class PrimaiteGame: if service_type == "DatabaseService": if "options" in service_cfg: opt = service_cfg["options"] + new_service.password = opt.get("db_password", None) new_service.configure_backup(backup_server=IPv4Address(opt.get("backup_server_ip"))) - new_service.start() if service_type == "FTPServer": if "options" in service_cfg: opt = service_cfg["options"] new_service.server_password = opt.get("server_password") - new_service.start() if service_type == "NTPClient": if "options" in service_cfg: opt = service_cfg["options"] new_service.ntp_server = IPv4Address(opt.get("ntp_server_ip")) - new_service.start() if "applications" in node_cfg: for application_cfg in node_cfg["applications"]: new_application = None @@ -304,7 +328,12 @@ class PrimaiteGame: new_application = new_node.software_manager.software[application_type] game.ref_map_applications[application_ref] = new_application.uuid else: - _LOGGER.warning(f"application type not found {application_type}") + msg = f"Configuration contains an invalid application type: {application_type}" + _LOGGER.error(msg) + raise ValueError(msg) + + # run the application + new_application.run() if application_type == "DataManipulationBot": if "options" in application_cfg: @@ -312,7 +341,7 @@ class PrimaiteGame: new_application.configure( server_ip_address=IPv4Address(opt.get("server_ip")), server_password=opt.get("server_password"), - payload=opt.get("payload"), + payload=opt.get("payload", "DELETE"), port_scan_p_of_success=float(opt.get("port_scan_p_of_success", "0.1")), data_manipulation_p_of_success=float(opt.get("data_manipulation_p_of_success", "0.1")), ) @@ -327,7 +356,6 @@ class PrimaiteGame: if "options" in application_cfg: opt = application_cfg["options"] new_application.target_url = opt.get("target_url") - elif application_type == "DoSBot": if "options" in application_cfg: opt = application_cfg["options"] @@ -344,10 +372,20 @@ class PrimaiteGame: for nic_num, nic_cfg in node_cfg["network_interfaces"].items(): new_node.connect_nic(NIC(ip_address=nic_cfg["ip_address"], subnet_mask=nic_cfg["subnet_mask"])) + # temporarily set to 0 so all nodes are initially on + new_node.start_up_duration = 0 + new_node.shut_down_duration = 0 + net.add_node(new_node) - new_node.power_on() + # run through the power on step if the node is to be turned on at the start + if new_node.operating_state == NodeOperatingState.ON: + new_node.power_on() game.ref_map_nodes[node_ref] = new_node.uuid + # set start up and shut down duration + new_node.start_up_duration = int(node_cfg.get("start_up_duration", 3)) + new_node.shut_down_duration = int(node_cfg.get("shut_down_duration", 3)) + # 2. create links between nodes for link_cfg in links_cfg: node_a = net.nodes[game.ref_map_nodes[link_cfg["endpoint_a_ref"]]] @@ -364,7 +402,7 @@ class PrimaiteGame: game.ref_map_links[link_cfg["ref"]] = new_link.uuid # 3. create agents - agents_cfg = cfg["agents"] + agents_cfg = cfg.get("agents", []) for agent_cfg in agents_cfg: agent_ref = agent_cfg["ref"] # noqa: F841 @@ -382,21 +420,19 @@ class PrimaiteGame: # CREATE REWARD FUNCTION reward_function = RewardFunction.from_config(reward_function_cfg) - # OTHER AGENT SETTINGS - agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) - # CREATE AGENT - if agent_type == "GreenWebBrowsingAgent": + if agent_type == "ProbabilisticAgent": # TODO: implement non-random agents and fix this parsing - new_agent = RandomAgent( + settings = agent_cfg.get("agent_settings", {}) + new_agent = ProbabilisticAgent( agent_name=agent_cfg["ref"], action_space=action_space, observation_space=obs_space, reward_function=reward_function, - agent_settings=agent_settings, + settings=settings, ) - game.agents.append(new_agent) elif agent_type == "ProxyAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) new_agent = ProxyAgent( agent_name=agent_cfg["ref"], action_space=action_space, @@ -404,9 +440,10 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) - game.rl_agents.append(new_agent) + game.rl_agents[agent_cfg["ref"]] = new_agent elif agent_type == "RedDatabaseCorruptingAgent": + agent_settings = AgentSettings.from_config(agent_cfg.get("agent_settings")) + new_agent = DataManipulationAgent( agent_name=agent_cfg["ref"], action_space=action_space, @@ -414,10 +451,55 @@ class PrimaiteGame: reward_function=reward_function, agent_settings=agent_settings, ) - game.agents.append(new_agent) else: - _LOGGER.warning(f"agent type {agent_type} not found") + msg = f"Configuration error: {agent_type} is not a valid agent type." + _LOGGER.error(msg) + raise ValueError(msg) + game.agents[agent_cfg["ref"]] = new_agent - game.simulation.set_original_state() + # Validate that if any agents are sharing rewards, they aren't forming an infinite loop. + game.setup_reward_sharing() + + # Set the NMNE capture config + set_nmne_config(network_config.get("nmne_config", {})) + game.update_agents(game.get_sim_state()) return game + + def setup_reward_sharing(self): + """Do necessary setup to enable reward sharing between agents. + + This method ensures that there are no cycles in the reward sharing. A cycle would be for example if agent_1 + depends on agent_2 and agent_2 depends on agent_1. It would cause an infinite loop. + + Also, SharedReward requires us to pass it a callback method that will provide the reward of the agent who is + sharing their reward. This callback is provided by this setup method. + + Finally, this method sorts the agents in order in which rewards will be evaluated to make sure that any rewards + that rely on the value of another reward are evaluated later. + + :raises RuntimeError: If the reward sharing is specified with a cyclic dependency. + """ + # construct dependency graph in the reward sharing between agents. + graph = {} + for name, agent in self.agents.items(): + graph[name] = set() + for comp, weight in agent.reward_function.reward_components: + if isinstance(comp, SharedReward): + comp: SharedReward + graph[name].add(comp.agent_name) + + # while constructing the graph, we might as well set up the reward sharing itself. + comp.callback = lambda agent_name: self.agents[agent_name].reward_function.current_reward + + # make sure the graph is acyclic. Otherwise we will enter an infinite loop of reward sharing. + if graph_has_cycle(graph): + raise RuntimeError( + ( + "Detected cycle in agent reward sharing. Check the agent reward function ", + "configuration: reward sharing can only go one way.", + ) + ) + + # sort the agents so the rewards that depend on other rewards are always evaluated later + self._reward_calculation_order = topological_sort(graph) diff --git a/src/primaite/game/science.py b/src/primaite/game/science.py index 19a86237..908b326f 100644 --- a/src/primaite/game/science.py +++ b/src/primaite/game/science.py @@ -1,4 +1,5 @@ from random import random +from typing import Any, Iterable, Mapping def simulate_trial(p_of_success: float) -> bool: @@ -14,3 +15,80 @@ def simulate_trial(p_of_success: float) -> bool: :returns: True if the trial is successful (with probability 'p_of_success'); otherwise, False. """ return random() < p_of_success + + +def graph_has_cycle(graph: Mapping[Any, Iterable[Any]]) -> bool: + """Detect cycles in a directed graph. + + Provide the graph as a dictionary that describes which nodes are linked. For example: + {0: {1,2}, 1:{2,3}, 3:{0}} here there's a cycle 0 -> 1 -> 3 -> 0 + {'a': ('b','c'), c:('b')} here there is no cycle + + :param graph: a mapping from node to a set of nodes to which it is connected. + :type graph: Mapping[Any, Iterable[Any]] + :return: Whether the graph has any cycles + :rtype: bool + """ + visited = set() + currently_visiting = set() + + def depth_first_search(node: Any) -> bool: + """Perform depth-first search (DFS) traversal to detect cycles starting from a given node.""" + if node in currently_visiting: + return True # Cycle detected + if node in visited: + return False # Already visited, no need to explore further + + visited.add(node) + currently_visiting.add(node) + + for neighbour in graph.get(node, []): + if depth_first_search(neighbour): + return True # Cycle detected + + currently_visiting.remove(node) + return False + + # Start DFS traversal from each node + for node in graph: + if depth_first_search(node): + return True # Cycle detected + + return False # No cycles found + + +def topological_sort(graph: Mapping[Any, Iterable[Any]]) -> Iterable[Any]: + """ + Perform topological sorting on a directed graph. + + This guarantees that if there's a directed edge from node A to node B, then A appears before B. + + :param graph: A dictionary representing the directed graph, where keys are node identifiers + and values are lists of outgoing edges from each node. + :type graph: dict[int, list[Any]] + + :return: A topologically sorted list of node identifiers. + :rtype: list[Any] + """ + visited: set[Any] = set() + stack: list[Any] = [] + + def dfs(node: Any) -> None: + """ + Depth-first search traversal to visit nodes and their neighbors. + + :param node: The current node to visit. + :type node: Any + """ + if node in visited: + return + visited.add(node) + for neighbour in graph.get(node, []): + dfs(neighbour) + stack.append(node) + + # Perform DFS traversal from each node + for node in graph: + dfs(node) + + return stack diff --git a/src/primaite/interface/__init__.py b/src/primaite/interface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/primaite/interface/request.py b/src/primaite/interface/request.py new file mode 100644 index 00000000..bc076599 --- /dev/null +++ b/src/primaite/interface/request.py @@ -0,0 +1,46 @@ +from typing import Dict, ForwardRef, List, Literal, Union + +from pydantic import BaseModel, ConfigDict, StrictBool, validate_call + +RequestFormat = List[Union[str, int, float]] + +RequestResponse = ForwardRef("RequestResponse") +"""This makes it possible to type-hint RequestResponse.from_bool return type.""" + + +class RequestResponse(BaseModel): + """Schema for generic request responses.""" + + model_config = ConfigDict(extra="forbid", strict=True) + """Cannot have extra fields in the response. Anything custom goes into the data field.""" + + status: Literal["pending", "success", "failure", "unreachable"] = "pending" + """ + What is the current status of the request: + - pending - the request has not been received yet, or it has been received but it's still being processed. + - success - the request has been received and executed successfully. + - failure - the request has been received and attempted, but execution failed. + - unreachable - the request could not reach it's intended target, either because it doesn't exist or the target + is off. + """ + + data: Dict = {} + """Catch-all place to provide any additional data that was generated as a response to the request.""" + # TODO: currently, status and data have default values, because I don't want to interrupt existing functionality too + # much. However, in the future we might consider making them mandatory. + + @classmethod + @validate_call + def from_bool(cls, status_bool: StrictBool) -> RequestResponse: + """ + Construct a basic request response from a boolean. + + True maps to a success status. False maps to a failure status. + + :param status_bool: Whether to create a successful response + :type status_bool: bool + """ + if status_bool is True: + return cls(status="success", data={}) + elif status_bool is False: + return cls(status="failure", data={}) diff --git a/src/primaite/main.py b/src/primaite/main.py index b63227a7..053ed65b 100644 --- a/src/primaite/main.py +++ b/src/primaite/main.py @@ -5,7 +5,7 @@ from pathlib import Path from typing import Optional, Union from primaite import getLogger -from primaite.config.load import example_config_path, load +from primaite.config.load import data_manipulation_config_path, load from primaite.session.session import PrimaiteSession # from primaite.primaite_session import PrimaiteSession @@ -42,6 +42,6 @@ if __name__ == "__main__": args = parser.parse_args() if not args.config: - args.config = example_config_path() + args.config = data_manipulation_config_path() run(args.config) diff --git a/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb new file mode 100644 index 00000000..56e9bf5a --- /dev/null +++ b/src/primaite/notebooks/Data-Manipulation-Customising-Red-Agent.ipynb @@ -0,0 +1,447 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Customising red agents\n", + "\n", + "This notebook will go over some examples of how red agent behaviour can be varied by changing its configuration parameters.\n", + "\n", + "First, let's load the standard Data Manipulation config file, and see what the red agent does.\n", + "\n", + "*(For a full explanation of the Data Manipulation scenario, check out the notebook `Data-Manipulation-E2E-Demonstration.ipynb`)*" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Imports\n", + "\n", + "from primaite.config.load import data_manipulation_config_path\n", + "from primaite.game.agent.interface import AgentActionHistoryItem\n", + "from primaite.session.environment import PrimaiteGymEnv\n", + "import yaml\n", + "from pprint import pprint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def make_cfg_have_flat_obs(cfg):\n", + " for agent in cfg['agents']:\n", + " if agent['type'] == \"ProxyAgent\":\n", + " agent['agent_settings']['flatten_obs'] = False" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + " make_cfg_have_flat_obs(cfg)\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", + "obs, info = env.reset()\n", + "print('env created successfully')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def friendly_output_red_action(info):\n", + " # parse the info dict form step output and write out what the red agent is doing\n", + " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_action = red_info.action\n", + " if red_action == 'DONOTHING':\n", + " red_str = 'DO NOTHING'\n", + " elif red_action == 'NODE_APPLICATION_EXECUTE':\n", + " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", + " red_str = f\"ATTACK from {client}\"\n", + " return red_str" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default, the red agent can start on client 1 or client 2. It starts its attack on a random step between 20 and 30, and it repeats its attack every 15-25 steps.\n", + "\n", + "It also has a 20% chance to fail to perform the port scan, and a 20% chance to fail launching the SQL attack. However it will continue where it left off after a failed step. I.e. if lucky, it can perform the port scan and SQL attack on the first try. If the port scan works, but the sql attack fails the first time it tries to attack, the next time it will not need to port scan again, it can go straight to trying to use SQL attack again." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for step in range(35):\n", + " step_num = env.game.step_counter\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " red = friendly_output_red_action(info)\n", + " print(f\"step: {step_num:3}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since the agent does nothing most of the time, let's only print the steps where it performs an attack." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.reset()\n", + "for step in range(100):\n", + " step_num = env.game.step_counter\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " red = friendly_output_red_action(info)\n", + " if red.startswith(\"ATTACK\"):\n", + " print(f\"step: {step_num:3}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Red Configuration\n", + "\n", + "There are two important parts of the YAML config for varying red agent behaviour.\n", + "\n", + "### Red agent settings\n", + "Here is an annotated config for the red agent in the data manipulation scenario.\n", + "```yaml\n", + " - ref: data_manipulation_attacker # name of agent\n", + " team: RED # not used, just for human reference\n", + " type: RedDatabaseCorruptingAgent # type of agent - this lets primaite know which agent class to use\n", + "\n", + " # Since the agent does not need to react to what is happening in the environment, the observation space is empty.\n", + " observation_space:\n", + " type: UC2RedObservation\n", + " options:\n", + " nodes: {}\n", + "\n", + " action_space:\n", + "\n", + " # The agent has two action choices, either do nothing, or execute a pre-scripted attack by using \n", + " action_list:\n", + " - type: DONOTHING\n", + " - type: NODE_APPLICATION_EXECUTE\n", + "\n", + " # The agent has access to the DataManipulationBoth on clients 1 and 2.\n", + " options:\n", + " nodes:\n", + " - node_name: client_1 # The network should have a node called client_1\n", + " applications:\n", + " - application_name: DataManipulationBot # The node client_1 should have DataManipulationBot configured on it\n", + " - node_name: client_2 # The network should have a node called client_2\n", + " applications:\n", + " - application_name: DataManipulationBot # The node client_2 should have DataManipulationBot configured on it\n", + "\n", + " # not important\n", + " max_folders_per_node: 1\n", + " max_files_per_folder: 1\n", + " max_services_per_node: 1\n", + "\n", + " # red agent does not need a reward function\n", + " reward_function:\n", + " reward_components:\n", + " - type: DUMMY\n", + "\n", + " # These actions are passed to the RedDatabaseCorruptingAgent init method, they dictate the schedule of attacks\n", + " agent_settings:\n", + " start_settings:\n", + " start_step: 25 # first attack at step 25\n", + " frequency: 20 # attacks will happen every 20 steps (on average)\n", + " variance: 5 # the timing of attacks will vary by up to 5 steps earlier or later\n", + "```\n", + "\n", + "### Malicious application settings\n", + "The red agent uses an application called `DataManipulationBot` which leverages a node's `DatabaseClient` to send a malicious SQL query to the database server. Here's an annotated example of how this is configured in the yaml *(with impertinent config items omitted)*:\n", + "```yaml\n", + "simulation:\n", + " network:\n", + " nodes:\n", + " - ref: client_1\n", + " hostname: client_1\n", + " type: computer\n", + " ip_address: 192.168.10.21\n", + " subnet_mask: 255.255.255.0\n", + " default_gateway: 192.168.10.1\n", + " \n", + " # \n", + " applications:\n", + " - ref: data_manipulation_bot\n", + " type: DataManipulationBot\n", + " options:\n", + " port_scan_p_of_success: 0.8 # Probability that port scan is successful\n", + " data_manipulation_p_of_success: 0.8 # Probability that SQL attack is successful\n", + " payload: \"DELETE\" # The SQL query which causes the attack (this has to be DELETE)\n", + " server_ip: 192.168.1.14 # IP address of server hosting the database\n", + " - ref: client_1_database_client\n", + " type: DatabaseClient # Database client must be installed in order for DataManipulationBot to function\n", + " options:\n", + " db_server_ip: 192.168.1.14 # IP address of server hosting the database\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Editing red agent settings\n", + "\n", + "### Removing randomness from attack timing\n", + "\n", + "We can make the attacks happen at completely predictable intervals if we edit the red agent's settings to set variance to 0." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "change = yaml.safe_load(\"\"\"\n", + "start_settings:\n", + " start_step: 25\n", + " frequency: 20\n", + " variance: 0\n", + "\"\"\")\n", + "\n", + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + " for agent in cfg['agents']:\n", + " if agent['ref'] == \"data_manipulation_attacker\":\n", + " agent['agent_settings'] = change\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", + "env.reset()\n", + "for step in range(100):\n", + " step_num = env.game.step_counter\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " red = friendly_output_red_action(info)\n", + " if red.startswith(\"ATTACK\"):\n", + " print(f\"step: {step_num:3}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Making the start node always the same\n", + "\n", + "Normally, the agent randomly chooses between the nodes in its action space to send attacks from:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Open the config without changing anything\n", + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", + "env.reset()\n", + "for ep in range(12):\n", + " env.reset()\n", + " for step in range(31):\n", + " step_num = env.game.step_counter\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " red = friendly_output_red_action(info)\n", + " if red.startswith(\"ATTACK\"):\n", + " print(f\"Episode: {ep:2}, step: {step_num:3}, Red action: {friendly_output_red_action(info)}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can make the agent always start on a node of our choice letting that be the only node in the agent's action space." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "change = yaml.safe_load(\"\"\"\n", + "action_space:\n", + " action_list:\n", + " - type: DONOTHING\n", + " - type: NODE_APPLICATION_EXECUTE\n", + " options:\n", + " nodes:\n", + " - node_name: client_1\n", + " applications:\n", + " - application_name: DataManipulationBot\n", + " max_folders_per_node: 1\n", + " max_files_per_folder: 1\n", + " max_services_per_node: 1\n", + "\"\"\")\n", + "\n", + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + " for agent in cfg['agents']:\n", + " if agent['ref'] == \"data_manipulation_attacker\":\n", + " agent.update(change)\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", + "env.reset()\n", + "for ep in range(12):\n", + " env.reset()\n", + " for step in range(31):\n", + " step_num = env.game.step_counter\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " red = friendly_output_red_action(info)\n", + " if red.startswith(\"ATTACK\"):\n", + " print(f\"Episode: {ep:2}, step: {step_num:3}, Red action: {friendly_output_red_action(info)}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Make the attack less likely to succeed.\n", + "\n", + "We can change the success probabilities within the data manipulation bot application. When the attack succeeds, the reward goes down.\n", + "\n", + "Setting the probabilities to 1.0 means the attack always succeeds - the reward will always drop\n", + "\n", + "Setting the probabilities to 0.0 means the attack always fails - the reward will never drop." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Make attack always succeed.\n", + "change = yaml.safe_load(\"\"\"\n", + " applications:\n", + " - ref: data_manipulation_bot\n", + " type: DataManipulationBot\n", + " options:\n", + " port_scan_p_of_success: 1.0\n", + " data_manipulation_p_of_success: 1.0\n", + " payload: \"DELETE\"\n", + " server_ip: 192.168.1.14\n", + " - ref: client_1_web_browser\n", + " type: WebBrowser\n", + " options:\n", + " target_url: http://arcd.com/users/\n", + " - ref: client_1_database_client\n", + " type: DatabaseClient\n", + " options:\n", + " db_server_ip: 192.168.1.14\n", + "\"\"\")\n", + "\n", + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + " cfg['simulation']['network']\n", + " for node in cfg['simulation']['network']['nodes']:\n", + " if node['ref'] in ['client_1', 'client_2']:\n", + " node['applications'] = change['applications']\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", + "env.reset()\n", + "for ep in range(5):\n", + " env.reset()\n", + " for step in range(36):\n", + " step_num = env.game.step_counter\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " red = friendly_output_red_action(info)\n", + " if step_num == 35:\n", + " print(f\"Episode: {ep:2}, step: {step_num:3}, Reward: {reward:.2f}\" )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Make attack always fail.\n", + "change = yaml.safe_load(\"\"\"\n", + " applications:\n", + " - ref: data_manipulation_bot\n", + " type: DataManipulationBot\n", + " options:\n", + " port_scan_p_of_success: 0.0\n", + " data_manipulation_p_of_success: 0.0\n", + " payload: \"DELETE\"\n", + " server_ip: 192.168.1.14\n", + " - ref: client_1_web_browser\n", + " type: WebBrowser\n", + " options:\n", + " target_url: http://arcd.com/users/\n", + " - ref: client_1_database_client\n", + " type: DatabaseClient\n", + " options:\n", + " db_server_ip: 192.168.1.14\n", + "\"\"\")\n", + "\n", + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n", + " cfg['simulation']['network']\n", + " for node in cfg['simulation']['network']['nodes']:\n", + " if node['ref'] in ['client_1', 'client_2']:\n", + " node['applications'] = change['applications']\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", + "env.reset()\n", + "for ep in range(5):\n", + " env.reset()\n", + " for step in range(36):\n", + " step_num = env.game.step_counter\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + " red = friendly_output_red_action(info)\n", + " if step_num == 35:\n", + " print(f\"Episode: {ep:2}, step: {step_num:3}, Reward: {reward:.2f}\" )" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/primaite/notebooks/uc2_demo.ipynb b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb similarity index 64% rename from src/primaite/notebooks/uc2_demo.ipynb rename to src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb index 7af1e605..7ec58b2c 100644 --- a/src/primaite/notebooks/uc2_demo.ipynb +++ b/src/primaite/notebooks/Data-Manipulation-E2E-Demonstration.ipynb @@ -13,7 +13,7 @@ "source": [ "## Scenario\n", "\n", - "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database.\n", + "The network consists of an office subnet and a server subnet. Clients in the office access a website which fetches data from a database. Occasionally, admins need to access the database directly from the clients.\n", "\n", "[](_package_data/uc2_network.png)\n", "\n", @@ -46,7 +46,9 @@ "source": [ "## Green agent\n", "\n", - "There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available." + "There are green agents logged onto client 1 and client 2. They use the web browser to navigate to `http://arcd.com/users`. The web server replies with a status code 200 if the data is available on the database or 404 if not available.\n", + "\n", + "Sometimes, the green agents send a request directly to the database to check that it is reachable." ] }, { @@ -55,7 +57,7 @@ "source": [ "## Red agent\n", "\n", - "The red agent waits a bit then sends a DELETE query to the database from client 1. If the delete is successful, the database file is flagged as compromised to signal that data is not available.\n", + "At the start of every episode, the red agent randomly chooses either client 1 or client 2 to login to. It waits a bit then sends a DELETE query to the database from its chosen client. If the delete is successful, the database file is flagged as compromised to signal that data is not available.\n", "\n", "[](_package_data/uc2_attack.png)\n", "\n", @@ -68,7 +70,9 @@ "source": [ "## Blue agent\n", "\n", - "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking client 1 from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router." + "The blue agent can view the entire network, but the health statuses of components are not updated until a scan is performed. The blue agent should restore the database file from backup after it was compromised. It can also prevent further attacks by blocking the red agent client from sending the malicious SQL query to the database server. This can be done by implementing an ACL rule on the router.\n", + "\n", + "However, these rules will also impact greens' ability to check the database connection. The blue agent should only block the infected client, it should let the other client connect freely. Once the attack has begun, automated traffic monitoring will detect it as suspicious network traffic. The blue agent's observation space will show this as an increase in the number of malicious network events (NMNE) on one of the network interfaces. To achieve optimal reward, the agent should only block the client which has the non-zero outbound NMNE." ] }, { @@ -84,7 +88,7 @@ "source": [ "## Scripted agents:\n", "### Red\n", - "The red agent sits on client 1 and uses an application called DataManipulationBot whose sole purpose is to send a DELETE query to the database.\n", + "The red agent sits on a client and uses an application called DataManipulationBot whose sole purpose is to send a DELETE query to the database.\n", "The red agent can choose one of two action each timestep:\n", "1. do nothing\n", "2. execute the data manipulation application\n", @@ -92,6 +96,7 @@ "- start time\n", "- frequency\n", "- variance\n", + "\n", "Attacks start at a random timestep between (start_time - variance) and (start_time + variance). After each attack, another is attempted after a random delay between (frequency - variance) and (frequency + variance) timesteps.\n", "\n", "The data manipulation app itself has an element of randomness because the attack has a probability of success. The default is 0.8 to succeed with the port scan step and 0.8 to succeed with the attack itself.\n", @@ -100,9 +105,11 @@ "The red agent does not use information about the state of the network to decide its action.\n", "\n", "### Green\n", - "The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, meaning it will request webpage with a 50% probability, and do nothing with a 50% probability.\n", + "The green agents use the web browser application to send requests to the web server. The schedule of each green agent is currently random, it will do nothing 30% of the time, send a web request 60% of the time, and send a db status check 10% of the time.\n", "\n", - "When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender." + "When a green agent is blocked from accessing the data through the webpage, this incurs a negative reward to the RL defender.\n", + "\n", + "Also, when the green agent is blocked from checking the database status, it causes a small negative reward." ] }, { @@ -129,6 +136,9 @@ " - NETWORK_INTERFACES\n", " - \n", " - nic_status\n", + " - nmne\n", + " - inbound\n", + " - outbound\n", " - operating_status\n", "- LINKS\n", " - \n", @@ -219,6 +229,14 @@ "|1|ENABLED|\n", "|2|DISABLED|\n", "\n", + "NMNE (number of malicious network events) means, for inbound or outbound traffic, means:\n", + "|value|NMNEs|\n", + "|--|--|\n", + "|0|None|\n", + "|1|1 - 5|\n", + "|2|6 - 10|\n", + "|3|More than 10|\n", + "\n", "Link load has the following meaning:\n", "|load|percent utilisation|\n", "|--|--|\n", @@ -289,11 +307,17 @@ "- `1`: Scan the web service - this refreshes the health status in the observation space\n", "- `9`: Scan the database file - this refreshes the health status of the database file\n", "- `13`: Patch the database service - This triggers the database to restore data from the backup server\n", - "- `19`: Shut down client 1\n", - "- `22`: Block outgoing traffic from client 1\n", - "- `26`: Block TCP traffic from client 1 to the database node\n", - "- `28-37`: Remove ACL rules 1-10\n", - "- `42`: Disconnect client 1 from the network\n", + "- `39`: Shut down client 1\n", + "- `40`: Start up client 1\n", + "- `46`: Block outgoing traffic from client 1\n", + "- `47`: Block outgoing traffic from client 2\n", + "- `50`: Block TCP traffic from client 1 to the database node\n", + "- `51`: Block TCP traffic from client 2 to the database node\n", + "- `52-61`: Remove ACL rules 1-10\n", + "- `66`: Disconnect client 1 from the network\n", + "- `67`: Reconnect client 1 to the network\n", + "- `68`: Disconnect client 2 from the network\n", + "- `69`: Reconnect client 2 to the network\n", "\n", "The other actions will either have no effect or will negatively impact the network, so the blue agent should avoid taking them." ] @@ -304,9 +328,10 @@ "source": [ "## Reward Function\n", "\n", - "The blue agent's reward is calculated using two measures:\n", + "The blue agent's reward is calculated using these measures:\n", "1. Whether the database file is in a good state (+1 for good, -1 for corrupted, 0 for any other state)\n", "2. Whether each green agents' most recent webpage request was successful (+1 for a `200` return code, -1 for a `404` return code and 0 otherwise).\n", + "3. Whether each green agents' most recent DB status check was successful (+1 for a successful connection, -1 for no connection).\n", "\n", "The file status reward and the two green-agent-related rewards are averaged to get a total step reward.\n" ] @@ -346,9 +371,9 @@ "outputs": [], "source": [ "# Imports\n", - "from primaite.config.load import example_config_path\n", + "from primaite.config.load import data_manipulation_config_path\n", "from primaite.session.environment import PrimaiteGymEnv\n", - "from primaite.game.game import PrimaiteGame\n", + "from primaite.game.agent.interface import AgentActionHistoryItem\n", "import yaml\n", "from pprint import pprint\n" ] @@ -365,21 +390,21 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "# create the env\n", - "with open(example_config_path(), 'r') as f:\n", + "with open(data_manipulation_config_path(), 'r') as f:\n", " cfg = yaml.safe_load(f)\n", " # set success probability to 1.0 to avoid rerunning cells.\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", + " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['data_manipulation_p_of_success'] = 1.0\n", " cfg['simulation']['network']['nodes'][8]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", - "game = PrimaiteGame.from_config(cfg)\n", - "env = PrimaiteGymEnv(game = game)\n", - "# Don't flatten obs as we are not training an agent and we wish to see the dict-formatted observations\n", - "env.agent.flatten_obs = False\n", + " cfg['simulation']['network']['nodes'][9]['applications'][0]['options']['port_scan_p_of_success'] = 1.0\n", + " # don't flatten observations so that we can see what is going on\n", + " cfg['agents'][3]['agent_settings']['flatten_obs'] = False\n", + "\n", + "env = PrimaiteGymEnv(game_config = cfg)\n", "obs, info = env.reset()\n", "print('env created successfully')\n", "pprint(obs)" @@ -389,35 +414,49 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will go from 1.0 to 0.0, and to -1.0 when the green agent tries to access the webpage." + "The red agent will start attacking at some point between step 20 and 30. When this happens, the reward will drop immediately, then drop to -0.8 when green agents try to access the webpage." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ - "for step in range(32):\n", + "def friendly_output_red_action(info):\n", + " # parse the info dict form step output and write out what the red agent is doing\n", + " red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n", + " red_action = red_info.action\n", + " if red_action == 'DONOTHING':\n", + " red_str = 'DO NOTHING'\n", + " elif red_action == 'NODE_APPLICATION_EXECUTE':\n", + " client = \"client 1\" if red_info.parameters['node_id'] == 0 else \"client 2\"\n", + " red_str = f\"ATTACK from {client}\"\n", + " return red_str" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for step in range(35):\n", " obs, reward, terminated, truncated, info = env.step(0)\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {friendly_output_red_action(info)}, Blue reward:{reward:.2f}\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now the reward is -1, let's have a look at blue agent's observation." + "Now the reward is -0.8, let's have a look at blue agent's observation." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "pprint(obs['NODES'])" @@ -433,9 +472,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(9) # scan database file\n", @@ -451,6 +488,13 @@ "File 1 in folder 1 on node 3 has `health_status = 2`, indicating that the database file is compromised." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Also, the NMNE outbound of either client 1 (node 6) or client 2 (node 7) increased from 0 to 1, but only right after the red attack, so we probably cannot see it now." + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -461,16 +505,14 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "obs, reward, terminated, truncated, info = env.step(13) # patch the database\n", "print(f\"step: {env.game.step_counter}\")\n", - "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", + "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_1_green_user'].action}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user'].action}\" )\n", "print(f\"Blue reward:{reward}\" )" ] }, @@ -480,65 +522,68 @@ "source": [ "The patching takes two steps, so the reward hasn't changed yet. Let's do nothing for another timestep, the reward should improve.\n", "\n", - "The reward will be 0 as soon as the file finishes restoring. Then, the reward will increase to 1 when the green agent makes a request. (Because the webapp access part of the reward does not update until a successful request is made.)\n", + "The reward will increase slightly as soon as the file finishes restoring. Then, the reward will increase to 1 when both green agents make successful requests.\n", "\n", - "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE`, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." + "Run the following cell until the green action is `NODE_APPLICATION_EXECUTE` for application 0, then the reward should become 1. If you run it enough times, another red attack will happen and the reward will drop again." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ - "obs, reward, terminated, truncated, info = env.step(0) # patch the database\n", + "obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", "print(f\"step: {env.game.step_counter}\")\n", - "print(f\"Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_2_green_user'][0]}\" )\n", - "print(f\"Green action: {info['agent_actions']['client_1_green_user'][0]}\" )\n", - "print(f\"Blue reward:{reward}\" )" + "print(f\"Red action: {info['agent_actions']['data_manipulation_attacker'].action}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_2_green_user']}\" )\n", + "print(f\"Green action: {info['agent_actions']['client_1_green_user']}\" )\n", + "print(f\"Blue reward:{reward:.2f}\" )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The blue agent can prevent attacks by implementing an ACL rule to stop client_1 from sending POSTGRES traffic to the database. (Let's also patch the database file to get the reward back up.)" + "The blue agent can prevent attacks by implementing an ACL rule to stop client_1 or client_2 from sending POSTGRES traffic to the database. (Let's also patch the database file to get the reward back up.)\n", + "\n", + "Let's block both clients from communicating directly with the database." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "env.step(13) # Patch the database\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", - "env.step(26) # Block client 1\n", - "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )\n", + "env.step(50) # Block client 1\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", "\n", - "for step in range(30):\n", + "env.step(51) # Block client 2\n", + "print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", + "\n", + "while abs(reward - 0.8) > 1e-5:\n", " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", - " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['client_1_data_manipulation_red_bot'][0]}, Blue reward:{reward}\" )" + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )\n", + " if env.game.step_counter > 10000:\n", + " break # make sure there's no infinite loop if something went wrong" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Now, even though the red agent executes an attack, the reward stays at 1.0" + "Now, even though the red agent executes an attack, the reward will stay at 0.8." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Let's also have a look at the ACL observation to verify our new ACL rule at position 5." + "Let's also have a look at the ACL observation to verify our new ACL rule at positions 5 and 6." ] }, { @@ -549,11 +594,92 @@ "source": [ "obs['ACL']" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can slightly increase the reward by unblocking the client which isn't being used by the attacker. If node 6 has outbound NMNEs, let's unblock client 2, and if node 7 has outbound NMNEs, let's unblock client 1." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.step(58) # Remove the ACL rule that blocks client 1\n", + "env.step(57) # Remove the ACL rule that blocks client 2\n", + "\n", + "tries = 0\n", + "while True:\n", + " tries += 1\n", + " obs, reward, terminated, truncated, info = env.step(0)\n", + "\n", + " if obs['NODES'][6]['NICS'][1]['NMNE']['outbound'] == 1:\n", + " # client 1 has NMNEs, let's block it\n", + " obs, reward, terminated, truncated, info = env.step(50) # block client 1\n", + " print(\"blocking client 1\")\n", + " break\n", + " elif obs['NODES'][7]['NICS'][1]['NMNE']['outbound'] == 1:\n", + " # client 2 has NMNEs, so let's block it\n", + " obs, reward, terminated, truncated, info = env.step(51) # block client 2\n", + " print(\"blocking client 2\")\n", + " break\n", + " if tries>100:\n", + " print(\"Error: NMNE never increased\")\n", + " break\n", + "\n", + "env.step(13) # Patch the database\n", + "print()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Now, the reward will eventually increase to 0.9, even after red agent attempts subsequent attacks." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "for step in range(40):\n", + " obs, reward, terminated, truncated, info = env.step(0) # do nothing\n", + " print(f\"step: {env.game.step_counter}, Red action: {info['agent_actions']['data_manipulation_attacker'].action}, Blue reward:{reward:.2f}\" )" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Reset the environment, you can rerun the other cells to verify that the attack works the same every episode. (except the red agent will move between `client_1` and `client_2`.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "venv", "language": "python", "name": "python3" }, @@ -567,9 +693,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.10.12" } }, "nbformat": 4, - "nbformat_minor": 4 + "nbformat_minor": 2 } diff --git a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb similarity index 91% rename from src/primaite/notebooks/training_example_ray_multi_agent.ipynb rename to src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb index 0d4b6d0e..76623697 100644 --- a/src/primaite/notebooks/training_example_ray_multi_agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLIB-MARL-System.ipynb @@ -35,7 +35,7 @@ "\n", "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", "# to copy the files to your user data path.\n", - "with open(PRIMAITE_PATHS.user_config_path / 'example_config/example_config_2_rl_agents.yaml', 'r') as f:\n", + "with open(PRIMAITE_PATHS.user_config_path / 'example_config/data_manipulation_marl.yaml', 'r') as f:\n", " cfg = yaml.safe_load(f)\n", "\n", "ray.init(local_mode=True)" @@ -60,7 +60,7 @@ " policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n", " policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n", " )\n", - " .environment(env=PrimaiteRayMARLEnv, env_config={\"cfg\":cfg})#, disable_env_checking=True)\n", + " .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n", " .rollouts(num_rollout_workers=0)\n", " .training(train_batch_size=128)\n", " )\n" @@ -88,6 +88,13 @@ " param_space=config\n", ").fit()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/notebooks/training_example_ray_single_agent.ipynb b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb similarity index 94% rename from src/primaite/notebooks/training_example_ray_single_agent.ipynb rename to src/primaite/notebooks/Training-an-RLLib-Agent.ipynb index ea006ae9..2fe84655 100644 --- a/src/primaite/notebooks/training_example_ray_single_agent.ipynb +++ b/src/primaite/notebooks/Training-an-RLLib-Agent.ipynb @@ -16,7 +16,7 @@ "source": [ "from primaite.game.game import PrimaiteGame\n", "import yaml\n", - "from primaite.config.load import example_config_path\n", + "from primaite.config.load import data_manipulation_config_path\n", "\n", "from primaite.session.environment import PrimaiteRayEnv\n", "from ray.rllib.algorithms import ppo\n", @@ -26,7 +26,7 @@ "\n", "# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n", "# to copy the files to your user data path.\n", - "with open(example_config_path(), 'r') as f:\n", + "with open(data_manipulation_config_path(), 'r') as f:\n", " cfg = yaml.safe_load(f)\n", "\n", "ray.init(local_mode=True)\n" @@ -54,7 +54,7 @@ "metadata": {}, "outputs": [], "source": [ - "env_config = {\"cfg\":cfg}\n", + "env_config = cfg\n", "\n", "config = (\n", " PPOConfig()\n", diff --git a/src/primaite/notebooks/training_example_sb3.ipynb b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb similarity index 81% rename from src/primaite/notebooks/training_example_sb3.ipynb rename to src/primaite/notebooks/Training-an-SB3-Agent.ipynb index e5085c5e..cefcc429 100644 --- a/src/primaite/notebooks/training_example_sb3.ipynb +++ b/src/primaite/notebooks/Training-an-SB3-Agent.ipynb @@ -17,7 +17,7 @@ "metadata": {}, "outputs": [], "source": [ - "from primaite.config.load import example_config_path" + "from primaite.config.load import data_manipulation_config_path" ] }, { @@ -26,10 +26,8 @@ "metadata": {}, "outputs": [], "source": [ - "with open(example_config_path(), 'r') as f:\n", - " cfg = yaml.safe_load(f)\n", - "\n", - "game = PrimaiteGame.from_config(cfg)" + "with open(data_manipulation_config_path(), 'r') as f:\n", + " cfg = yaml.safe_load(f)\n" ] }, { @@ -38,7 +36,7 @@ "metadata": {}, "outputs": [], "source": [ - "gym = PrimaiteGymEnv(game=game)" + "gym = PrimaiteGymEnv(game_config=cfg)" ] }, { @@ -65,7 +63,7 @@ "metadata": {}, "outputs": [], "source": [ - "model.learn(total_timesteps=1000)\n" + "model.learn(total_timesteps=10)\n" ] }, { @@ -76,6 +74,13 @@ "source": [ "model.save(\"deleteme\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index a3831bc1..1795f14b 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,14 +1,19 @@ +import copy import json -from typing import Any, Dict, Final, Optional, SupportsFloat, Tuple +from typing import Any, Dict, Optional, SupportsFloat, Tuple import gymnasium from gymnasium.core import ActType, ObsType from ray.rllib.env.multi_agent_env import MultiAgentEnv +from primaite import getLogger from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT +_LOGGER = getLogger(__name__) + class PrimaiteGymEnv(gymnasium.Env): """ @@ -18,40 +23,56 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game: PrimaiteGame): + def __init__(self, game_config: Dict): """Initialise the environment.""" super().__init__() - self.game: "PrimaiteGame" = game - self.agent: ProxyAgent = self.game.rl_agents[0] + self.game_config: Dict = game_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) + """Current game.""" + self._agent_name = next(iter(self.game.rl_agents)) + """Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key.""" + + self.episode_counter: int = 0 + """Current episode number.""" + + self.io = PrimaiteIO.from_config(game_config.get("io_settings", {})) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + + @property + def agent(self) -> ProxyAgent: + """Grab a fresh reference to the agent object because it will be reinstantiated each episode.""" + return self.game.rl_agents[self._agent_name] def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """Perform a step in the environment.""" # make ProxyAgent store the action chosen my the RL policy self.agent.store_action(action) # apply_agent_actions accesses the action we just stored - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() self.game.advance_timestep() state = self.game.get_sim_state() - self.game.update_agents(state) - next_obs = self._get_obs() + next_obs = self._get_obs() # this doesn't update observation, just gets the current observation reward = self.agent.reward_function.current_reward terminated = False truncated = self.game.calculate_truncated() - info = {"agent_actions": agent_actions} # tell us what all the agents did for convenience. + info = { + "agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()} + } # tell us what all the agents did for convenience. if self.game.save_step_metadata: self._write_step_metadata_json(action, state, reward) return next_obs, reward, terminated, truncated, info def _write_step_metadata_json(self, action: int, state: Dict, reward: int): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "action": int(action), "reward": int(reward), @@ -62,13 +83,18 @@ class PrimaiteGymEnv(gymnasium.Env): def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]: """Reset the environment.""" - print( - f"Resetting environment, episode {self.game.episode_counter}, " - f"avg. reward: {self.game.rl_agents[0].reward_function.total_reward}" + _LOGGER.info( + f"Resetting environment, episode {self.episode_counter}, " + f"avg. reward: {self.agent.reward_function.total_reward}" ) - self.game.reset() + if self.io.settings.save_agent_actions: + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() - self.game.update_agents(state) + self.game.update_agents(state=state) next_obs = self._get_obs() info = {} return next_obs, info @@ -88,12 +114,12 @@ class PrimaiteGymEnv(gymnasium.Env): def _get_obs(self) -> ObsType: """Return the current observation.""" - if not self.agent.flatten_obs: - return self.agent.observation_manager.current_observation - else: + if self.agent.flatten_obs: unflat_space = self.agent.observation_manager.space unflat_obs = self.agent.observation_manager.current_observation return gymnasium.spaces.flatten(unflat_space, unflat_obs) + else: + return self.agent.observation_manager.current_observation class PrimaiteRayEnv(gymnasium.Env): @@ -102,12 +128,11 @@ class PrimaiteRayEnv(gymnasium.Env): def __init__(self, env_config: Dict) -> None: """Initialise the environment. - :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` - which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :param env_config: A dictionary containing the environment configuration. + :type env_config: Dict """ - self.env = PrimaiteGymEnv(game=PrimaiteGame.from_config(env_config["cfg"])) - self.env.game.episode_counter -= 1 + self.env = PrimaiteGymEnv(game_config=env_config) + self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -128,13 +153,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): :param env_config: A dictionary containing the environment configuration. It must contain a single key, `game` which is the PrimaiteGame instance. - :type env_config: Dict[str, PrimaiteGame] + :type env_config: Dict """ - self.game: PrimaiteGame = PrimaiteGame.from_config(env_config["cfg"]) + self.game_config: Dict = env_config + """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" + self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) """Reference to the primaite game""" - self.agents: Final[Dict[str, ProxyAgent]] = {agent.agent_name: agent for agent in self.game.rl_agents} - """List of all possible agents in the environment. This list should not change!""" - self._agent_ids = list(self.agents.keys()) + self._agent_ids = list(self.game.rl_agents.keys()) + """Agent ids. This is a list of strings of agent names.""" + self.episode_counter: int = 0 + """Current episode number.""" self.terminateds = set() self.truncateds = set() @@ -147,11 +175,25 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): self.action_space = gymnasium.spaces.Dict( {name: agent.action_manager.space for name, agent in self.agents.items()} ) + + self.io = PrimaiteIO.from_config(env_config.get("io_settings")) + """Handles IO for the environment. This produces sys logs, agent logs, etc.""" + super().__init__() + @property + def agents(self) -> Dict[str, ProxyAgent]: + """Grab a fresh reference to the agents from this episode's game object.""" + return {name: self.game.rl_agents[name] for name in self._agent_ids} + def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]: """Reset the environment.""" - self.game.reset() + if self.io.settings.save_agent_actions: + all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} + self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) + self.game.setup_for_episode(episode=self.episode_counter) + self.episode_counter += 1 state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() @@ -172,7 +214,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): # 1. Perform actions for agent_name, action in actions.items(): self.agents[agent_name].store_action(action) - agent_actions = self.game.apply_agent_actions() + self.game.apply_agent_actions() # 2. Advance timestep self.game.advance_timestep() @@ -186,7 +228,7 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()} terminateds = {name: False for name, _ in self.agents.items()} truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()} - infos = {"agent_actions": agent_actions} + infos = {name: {} for name, _ in self.agents.items()} terminateds["__all__"] = len(self.terminateds) == len(self.agents) truncateds["__all__"] = self.game.calculate_truncated() if self.game.save_step_metadata: @@ -194,13 +236,13 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): return next_obs, rewards, terminateds, truncateds, infos def _write_step_metadata_json(self, actions: Dict, state: Dict, rewards: Dict): - output_dir = SIM_OUTPUT.path / f"episode_{self.game.episode_counter}" / "step_metadata" + output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata" output_dir.mkdir(parents=True, exist_ok=True) path = output_dir / f"step_{self.game.step_counter}.json" data = { - "episode": self.game.episode_counter, + "episode": self.episode_counter, "step": self.game.step_counter, "actions": {agent_name: int(action) for agent_name, action in actions.items()}, "reward": rewards, @@ -212,8 +254,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): def _get_obs(self) -> Dict[str, ObsType]: """Return the current observation.""" obs = {} - for name, agent in self.agents.items(): + for agent_name in self._agent_ids: + agent = self.game.rl_agents[agent_name] unflat_space = agent.observation_manager.space unflat_obs = agent.observation_manager.current_observation - obs[name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) + obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs) return obs diff --git a/src/primaite/session/io.py b/src/primaite/session/io.py index b4b740e9..e57f88ae 100644 --- a/src/primaite/session/io.py +++ b/src/primaite/session/io.py @@ -1,56 +1,53 @@ +import json from datetime import datetime from pathlib import Path -from typing import Optional +from typing import Dict, List, Optional from pydantic import BaseModel, ConfigDict -from primaite import PRIMAITE_PATHS +from primaite import getLogger, PRIMAITE_PATHS from primaite.simulator import SIM_OUTPUT - -class SessionIOSettings(BaseModel): - """Schema for session IO settings.""" - - model_config = ConfigDict(extra="forbid") - - save_final_model: bool = True - """Whether to save the final model right at the end of training.""" - save_checkpoints: bool = False - """Whether to save a checkpoint model every `checkpoint_interval` episodes""" - checkpoint_interval: int = 10 - """How often to save a checkpoint model (if save_checkpoints is True).""" - save_logs: bool = True - """Whether to save logs""" - save_transactions: bool = True - """Whether to save transactions, If true, the session path will have a transactions folder.""" - save_tensorboard_logs: bool = False - """Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder.""" - save_step_metadata: bool = False - """Whether to save the RL agents' action, environment state, and other data at every single step.""" - save_pcap_logs: bool = False - """Whether to save PCAP logs.""" - save_sys_logs: bool = False - """Whether to save system logs.""" +_LOGGER = getLogger(__name__) -class SessionIO: +class PrimaiteIO: """ Class for managing session IO. - Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on. + Currently it's handling path generation, but could expand to handle loading, transaction, and so on. """ - def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None: - self.settings: SessionIOSettings = settings + class Settings(BaseModel): + """Config schema for PrimaiteIO object.""" + + model_config = ConfigDict(extra="forbid") + + save_logs: bool = True + """Whether to save logs""" + save_agent_actions: bool = True + """Whether to save a log of all agents' actions every step.""" + save_step_metadata: bool = False + """Whether to save the RL agents' action, environment state, and other data at every single step.""" + save_pcap_logs: bool = False + """Whether to save PCAP logs.""" + save_sys_logs: bool = False + """Whether to save system logs.""" + + def __init__(self, settings: Optional[Settings] = None) -> None: + """ + Init the PrimaiteIO object. + + Note: Instantiating this object creates a new directory for outputs, and sets the global SIM_OUTPUT variable. + It is intended that this object is instantiated when a new environment is created. + """ + self.settings = settings or PrimaiteIO.Settings() self.session_path: Path = self.generate_session_path() # set global SIM_OUTPUT path SIM_OUTPUT.path = self.session_path / "simulation_output" SIM_OUTPUT.save_pcap_logs = self.settings.save_pcap_logs SIM_OUTPUT.save_sys_logs = self.settings.save_sys_logs - # warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's - # possible refactor needed - def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path: """Create a folder for the session and return the path to it.""" if timestamp is None: @@ -68,3 +65,32 @@ class SessionIO: def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path: """Return the path where the checkpoint model will be saved (excluding filename extension).""" return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt" + + def generate_agent_actions_save_path(self, episode: int) -> Path: + """Return the path where agent actions will be saved.""" + return self.session_path / "agent_actions" / f"episode_{episode}.json" + + def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None: + """Take the contents of the agent action log and write it to a file. + + :param episode: Episode number + :type episode: int + """ + data = {} + longest_history = max([len(hist) for hist in agent_actions.values()]) + for i in range(longest_history): + data[i] = {"timestep": i, "episode": episode} + data[i].update({name: acts[i] for name, acts in agent_actions.items() if len(acts) > i}) + + path = self.generate_agent_actions_save_path(episode=episode) + path.parent.mkdir(exist_ok=True, parents=True) + path.touch() + _LOGGER.info(f"Saving agent action log to {path}") + with open(path, "w") as file: + json.dump(data, fp=file, indent=1, default=lambda x: x.model_dump()) + + @classmethod + def from_config(cls, config: Dict) -> "PrimaiteIO": + """Create an instance of PrimaiteIO based on a configuration dict.""" + new = cls(settings=cls.Settings(**config)) + return new diff --git a/src/primaite/session/policy/sb3.py b/src/primaite/session/policy/sb3.py index 254baf4d..6220371d 100644 --- a/src/primaite/session/policy/sb3.py +++ b/src/primaite/session/policy/sb3.py @@ -39,9 +39,9 @@ class SB3Policy(PolicyABC, identifier="SB3"): def learn(self, n_episodes: int, timesteps_per_episode: int) -> None: """Train the agent.""" - if self.session.io_manager.settings.save_checkpoints: + if self.session.save_checkpoints: checkpoint_callback = CheckpointCallback( - save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval, + save_freq=timesteps_per_episode * self.session.checkpoint_interval, save_path=self.session.io_manager.generate_model_save_path("sb3"), name_prefix="sb3_model", ) diff --git a/src/primaite/session/session.py b/src/primaite/session/session.py index 5c663cfd..9c935ae3 100644 --- a/src/primaite/session/session.py +++ b/src/primaite/session/session.py @@ -1,12 +1,12 @@ +# raise DeprecationWarning("This module is deprecated") from enum import Enum from pathlib import Path from typing import Dict, List, Literal, Optional, Union from pydantic import BaseModel, ConfigDict -from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv -from primaite.session.io import SessionIO, SessionIOSettings +from primaite.session.io import PrimaiteIO # from primaite.game.game import PrimaiteGame from primaite.session.policy.policy import PolicyABC @@ -40,7 +40,7 @@ class SessionMode(Enum): class PrimaiteSession: """The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments.""" - def __init__(self, game: PrimaiteGame): + def __init__(self, game_cfg: Dict): """Initialise PrimaiteSession object.""" self.training_options: TrainingOptions """Options specific to agent training.""" @@ -54,12 +54,18 @@ class PrimaiteSession: self.policy: PolicyABC """The reinforcement learning policy.""" - self.io_manager: Optional["SessionIO"] = None + self.io_manager: Optional["PrimaiteIO"] = None """IO manager for the session.""" - self.game: PrimaiteGame = game + self.game_cfg: Dict = game_cfg """Primaite Game object for managing main simulation loop and agents.""" + self.save_checkpoints: bool = False + """Whether to save checkpoints.""" + + self.checkpoint_interval: int = 10 + """If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes.""" + def start_session(self) -> None: """Commence the training/eval session.""" print("Starting Primaite Session") @@ -90,22 +96,21 @@ class PrimaiteSession: def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession": """Create a PrimaiteSession object from a config dictionary.""" # READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS... - io_settings = cfg.get("io_settings", {}) - io_manager = SessionIO(SessionIOSettings(**io_settings)) + io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {})) - game = PrimaiteGame.from_config(cfg) - - sess = cls(game=game) + sess = cls(game_cfg=cfg) sess.io_manager = io_manager sess.training_options = TrainingOptions(**cfg["training_config"]) + sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints") + sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval") # CREATE ENVIRONMENT if sess.training_options.rl_framework == "RLLIB_single_agent": - sess.env = PrimaiteRayEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayEnv(env_config=cfg) elif sess.training_options.rl_framework == "RLLIB_multi_agent": - sess.env = PrimaiteRayMARLEnv(env_config={"cfg": cfg}) + sess.env = PrimaiteRayMARLEnv(env_config=cfg) elif sess.training_options.rl_framework == "SB3": - sess.env = PrimaiteGymEnv(game=game) + sess.env = PrimaiteGymEnv(game_config=cfg) sess.policy = PolicyABC.from_config(sess.training_options, session=sess) if agent_load_path: diff --git a/src/primaite/simulator/__init__.py b/src/primaite/simulator/__init__.py index 97bcd57b..aebd77cf 100644 --- a/src/primaite/simulator/__init__.py +++ b/src/primaite/simulator/__init__.py @@ -12,8 +12,8 @@ class _SimOutput: self._path: Path = ( _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S") ) - self.save_pcap_logs: bool = True - self.save_sys_logs: bool = True + self.save_pcap_logs: bool = False + self.save_sys_logs: bool = False @property def path(self) -> Path: diff --git a/src/primaite/simulator/core.py b/src/primaite/simulator/core.py index 964dac01..6da8a2f8 100644 --- a/src/primaite/simulator/core.py +++ b/src/primaite/simulator/core.py @@ -1,12 +1,13 @@ # flake8: noqa """Core of the PrimAITE Simulator.""" -from abc import ABC, abstractmethod -from typing import Callable, ClassVar, Dict, List, Optional, Union +from abc import abstractmethod +from typing import Callable, Dict, List, Literal, Optional, Union from uuid import uuid4 -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, validate_call from primaite import getLogger +from primaite.interface.request import RequestFormat, RequestResponse _LOGGER = getLogger(__name__) @@ -21,15 +22,15 @@ class RequestPermissionValidator(BaseModel): """ @abstractmethod - def __call__(self, request: List[str], context: Dict) -> bool: - """Use the request and context paramters to decide whether the request should be permitted.""" + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Use the request and context parameters to decide whether the request should be permitted.""" pass class AllowAllValidator(RequestPermissionValidator): """Always allows the request.""" - def __call__(self, request: List[str], context: Dict) -> bool: + def __call__(self, request: RequestFormat, context: Dict) -> bool: """Always allow the request.""" return True @@ -42,7 +43,7 @@ class RequestType(BaseModel): the request can be performed or not. """ - func: Callable[[List[str], Dict], None] + func: Callable[[RequestFormat, Dict], RequestResponse] """ ``func`` is a function that accepts a request and a context dict. Typically this would be a lambda function that invokes a class method of your SimComponent. For example if the component is a node and the request type is for @@ -71,7 +72,7 @@ class RequestManager(BaseModel): request_types: Dict[str, RequestType] = {} """maps request name to an RequestType object.""" - def __call__(self, request: Callable[[List[str], Dict], None], context: Dict) -> None: + def __call__(self, request: RequestFormat, context: Dict) -> RequestResponse: """ Process an request request. @@ -84,23 +85,23 @@ class RequestManager(BaseModel): :raises RuntimeError: If the request parameter does not have a valid request name as the first item. """ request_key = request[0] + request_options = request[1:] if request_key not in self.request_types: msg = ( f"Request {request} could not be processed because {request_key} is not a valid request name", "within this RequestManager", ) - _LOGGER.error(msg) - raise RuntimeError(msg) + _LOGGER.debug(msg) + return RequestResponse(status="unreachable", data={"reason": msg}) request_type = self.request_types[request_key] - request_options = request[1:] if not request_type.validator(request_options, context): _LOGGER.debug(f"Request {request} was denied due to insufficient permissions") - return + return RequestResponse(status="failure", data={"reason": "request validation failed"}) - request_type.func(request_options, context) + return request_type.func(request_options, context) def add_request(self, name: str, request_type: RequestType) -> None: """ @@ -153,22 +154,18 @@ class SimComponent(BaseModel): uuid: str = Field(default_factory=lambda: str(uuid4())) """The component UUID.""" - _original_state: Dict = {} - def __init__(self, **kwargs): super().__init__(**kwargs) self._request_manager: RequestManager = self._init_request_manager() self._parent: Optional["SimComponent"] = None - # @abstractmethod - def set_original_state(self): - """Sets the original state.""" - pass + def setup_for_episode(self, episode: int): + """ + Perform any additional setup on this component that can't happen during __init__. - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - for key, value in self._original_state.items(): - self.__setattr__(key, value) + For instance, some components may require for the entire network to exist before some configuration can be set. + """ + pass def _init_request_manager(self) -> RequestManager: """ @@ -206,7 +203,8 @@ class SimComponent(BaseModel): } return state - def apply_request(self, request: List[str], context: Dict = {}) -> None: + @validate_call + def apply_request(self, request: RequestFormat, context: Dict = {}) -> RequestResponse: """ Apply a request to a simulation component. Request data is passed in as a 'namespaced' list of strings. @@ -226,7 +224,7 @@ class SimComponent(BaseModel): """ if self._request_manager is None: return - self._request_manager(request, context) + return self._request_manager(request, context) def apply_timestep(self, timestep: int) -> None: """ diff --git a/src/primaite/simulator/domain/account.py b/src/primaite/simulator/domain/account.py index d9dad06a..186caf5b 100644 --- a/src/primaite/simulator/domain/account.py +++ b/src/primaite/simulator/domain/account.py @@ -42,19 +42,6 @@ class Account(SimComponent): "Account Type, currently this can be service account (used by apps) or user account." enabled: bool = True - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "num_logons", - "num_logoffs", - "num_group_changes", - "username", - "password", - "account_type", - "enabled", - } - self._original_state = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/domain/controller.py b/src/primaite/simulator/domain/controller.py index bc428743..432a1d9a 100644 --- a/src/primaite/simulator/domain/controller.py +++ b/src/primaite/simulator/domain/controller.py @@ -80,6 +80,11 @@ class DomainController(SimComponent): super().__init__(**kwargs) def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() # Action 'account' matches requests like: # ['account', '', *account_action] @@ -87,6 +92,7 @@ class DomainController(SimComponent): "account", RequestType( func=lambda request, context: self.accounts[request.pop(0)].apply_request(request, context), + # TODO: not sure what should get returned here, revisit validator=GroupMembershipValidator(allowed_groups=[AccountGroup.DOMAIN_ADMIN]), ), ) diff --git a/src/primaite/simulator/file_system/file.py b/src/primaite/simulator/file_system/file.py index 608a1d78..9331c40c 100644 --- a/src/primaite/simulator/file_system/file.py +++ b/src/primaite/simulator/file_system/file.py @@ -38,6 +38,8 @@ class File(FileSystemItemABC): "The Path if real is True." sim_root: Optional[Path] = None "Root path of the simulation." + num_access: int = 0 + "Number of times the file was accessed in the current step." def __init__(self, **kwargs): """ @@ -73,20 +75,6 @@ class File(FileSystemItemABC): self.sys_log.info(f"Created file /{self.path} (id: {self.uuid})") - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting File ({self.path}) original state on node {self.sys_log.hostname}") - super().set_original_state() - vals_to_include = {"folder_id", "folder_name", "file_type", "sim_size", "real", "sim_path", "sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting File ({self.path}) state on node {self.sys_log.hostname}") - super().reset_component_for_episode(episode) - @property def path(self) -> str: """ @@ -107,22 +95,36 @@ class File(FileSystemItemABC): return os.path.getsize(self.sim_path) return self.sim_size + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the file. + + :param timestep: The current timestep of the simulation. + """ + super().apply_timestep(timestep=timestep) + + # reset the number of accesses to 0 + self.num_access = 0 + def describe_state(self) -> Dict: """Produce a dictionary describing the current state of this object.""" state = super().describe_state() state["size"] = self.size state["file_type"] = self.file_type.name + state["num_access"] = self.num_access return state - def scan(self) -> None: + def scan(self) -> bool: """Updates the visible statuses of the file.""" if self.deleted: self.sys_log.error(f"Unable to scan deleted file {self.folder_name}/{self.name}") - return + return False + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Scanning file {self.sim_path if self.sim_path else path}") self.visible_health_status = self.health_status + return True def reveal_to_red(self) -> None: """Reveals the folder/file to the red agent.""" @@ -131,7 +133,7 @@ class File(FileSystemItemABC): return self.revealed_to_red = True - def check_hash(self) -> None: + def check_hash(self) -> bool: """ Check if the file has been changed. @@ -141,7 +143,7 @@ class File(FileSystemItemABC): """ if self.deleted: self.sys_log.error(f"Unable to check hash of deleted file {self.folder_name}/{self.name}") - return + return False current_hash = None # if file is real, read the file contents @@ -163,50 +165,59 @@ class File(FileSystemItemABC): # if the previous hash and current hash do not match, mark file as corrupted if self.previous_hash is not current_hash: self.corrupt() + return True - def repair(self) -> None: + def repair(self) -> bool: """Repair a corrupted File by setting the status to FileSystemItemStatus.GOOD.""" if self.deleted: self.sys_log.error(f"Unable to repair deleted file {self.folder_name}/{self.name}") - return + return False # set file status to good if corrupt if self.health_status == FileSystemItemHealthStatus.CORRUPT: self.health_status = FileSystemItemHealthStatus.GOOD + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Repaired file {self.sim_path if self.sim_path else path}") + return True - def corrupt(self) -> None: + def corrupt(self) -> bool: """Corrupt a File by setting the status to FileSystemItemStatus.CORRUPT.""" if self.deleted: self.sys_log.error(f"Unable to corrupt deleted file {self.folder_name}/{self.name}") - return + return False # set file status to good if corrupt if self.health_status == FileSystemItemHealthStatus.GOOD: self.health_status = FileSystemItemHealthStatus.CORRUPT + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Corrupted file {self.sim_path if self.sim_path else path}") + return True - def restore(self) -> None: + def restore(self) -> bool: """Determines if the file needs to be repaired or unmarked as deleted.""" if self.deleted: self.deleted = False - return + return True if self.health_status == FileSystemItemHealthStatus.CORRUPT: self.health_status = FileSystemItemHealthStatus.GOOD + self.num_access += 1 # file was accessed path = self.folder.name + "/" + self.name self.sys_log.info(f"Restored file {self.sim_path if self.sim_path else path}") + return True - def delete(self): + def delete(self) -> bool: """Marks the file as deleted.""" if self.deleted: self.sys_log.error(f"Unable to delete an already deleted file {self.folder_name}/{self.name}") - return + return False + self.num_access += 1 # file was accessed self.deleted = True self.sys_log.info(f"File deleted {self.folder_name}/{self.name}") + return True diff --git a/src/primaite/simulator/file_system/file_system.py b/src/primaite/simulator/file_system/file_system.py index ee80587d..9166178c 100644 --- a/src/primaite/simulator/file_system/file_system.py +++ b/src/primaite/simulator/file_system/file_system.py @@ -7,6 +7,7 @@ from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_type import FileType @@ -27,6 +28,10 @@ class FileSystem(SimComponent): "Instance of SysLog used to create system logs." sim_root: Path "Root path of the simulation." + num_file_creations: int = 0 + "Number of file creations in the current step." + num_file_deletions: int = 0 + "Number of file deletions in the current step." def __init__(self, **kwargs): super().__init__(**kwargs) @@ -34,56 +39,28 @@ class FileSystem(SimComponent): if not self.folders: self.create_folder("root") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FileSystem original state on node {self.sys_log.hostname}") - for folder in self.folders.values(): - folder.set_original_state() - # Capture a list of all 'original' file uuids - original_keys = list(self.folders.keys()) - vals_to_include = {"sim_root"} - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_folder_uuids"] = original_keys - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FileSystem state on node {self.sys_log.hostname}") - # Move any 'original' folder that have been deleted back to folders - original_folder_uuids = self._original_state["original_folder_uuids"] - for uuid in original_folder_uuids: - if uuid in self.deleted_folders: - folder = self.deleted_folders[uuid] - self.deleted_folders.pop(uuid) - self.folders[uuid] = folder - - # Clear any other deleted folders that aren't original (have been created by agent) - self.deleted_folders.clear() - - # Now clear all non-original folders created by agent - current_folder_uuids = list(self.folders.keys()) - for uuid in current_folder_uuids: - if uuid not in original_folder_uuids: - folder = self.folders[uuid] - self.folders.pop(uuid) - - # Now reset all remaining folders - for folder in self.folders.values(): - folder.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() self._delete_manager = RequestManager() self._delete_manager.add_request( name="file", request_type=RequestType( - func=lambda request, context: self.delete_file(folder_name=request[0], file_name=request[1]) + func=lambda request, context: RequestResponse.from_bool( + self.delete_file(folder_name=request[0], file_name=request[1]) + ) ), ) self._delete_manager.add_request( name="folder", - request_type=RequestType(func=lambda request, context: self.delete_folder(folder_name=request[0])), + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.delete_folder(folder_name=request[0])) + ), ) rm.add_request( name="delete", @@ -94,12 +71,16 @@ class FileSystem(SimComponent): self._restore_manager.add_request( name="file", request_type=RequestType( - func=lambda request, context: self.restore_file(folder_name=request[0], file_name=request[1]) + func=lambda request, context: RequestResponse.from_bool( + self.restore_file(folder_name=request[0], file_name=request[1]) + ) ), ) self._restore_manager.add_request( name="folder", - request_type=RequestType(func=lambda request, context: self.restore_folder(folder_name=request[0])), + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.restore_folder(folder_name=request[0])) + ), ) rm.add_request( name="restore", @@ -175,7 +156,7 @@ class FileSystem(SimComponent): ) return folder - def delete_folder(self, folder_name: str): + def delete_folder(self, folder_name: str) -> bool: """ Deletes a folder, removes it from the folders list and removes any child folders and files. @@ -183,24 +164,26 @@ class FileSystem(SimComponent): """ if folder_name == "root": self.sys_log.warning("Cannot delete the root folder.") - return + return False folder = self.get_folder(folder_name) - if folder: - # set folder to deleted state - folder.delete() - - # remove from folder list - self.folders.pop(folder.uuid) - - # add to deleted list - folder.remove_all_files() - - self.deleted_folders[folder.uuid] = folder - self.sys_log.info(f"Deleted folder /{folder.name} and its contents") - else: + if not folder: _LOGGER.debug(f"Cannot delete folder as it does not exist: {folder_name}") + return False - def delete_folder_by_id(self, folder_uuid: str): + # set folder to deleted state + folder.delete() + + # remove from folder list + self.folders.pop(folder.uuid) + + # add to deleted list + folder.remove_all_files() + + self.deleted_folders[folder.uuid] = folder + self.sys_log.info(f"Deleted folder /{folder.name} and its contents") + return True + + def delete_folder_by_id(self, folder_uuid: str) -> None: """ Deletes a folder via its uuid. @@ -285,6 +268,8 @@ class FileSystem(SimComponent): ) folder.add_file(file) self._file_request_manager.add_request(name=file.name, request_type=RequestType(func=file._request_manager)) + # increment file creation + self.num_file_creations += 1 return file def get_file(self, folder_name: str, file_name: str, include_deleted: Optional[bool] = False) -> Optional[File]: @@ -334,7 +319,7 @@ class FileSystem(SimComponent): return file - def delete_file(self, folder_name: str, file_name: str): + def delete_file(self, folder_name: str, file_name: str) -> bool: """ Delete a file by its name from a specific folder. @@ -345,9 +330,13 @@ class FileSystem(SimComponent): if folder: file = folder.get_file(file_name) if file: + # increment file creation + self.num_file_deletions += 1 folder.remove_file(file) + return True + return False - def delete_file_by_id(self, folder_uuid: str, file_uuid: str): + def delete_file_by_id(self, folder_uuid: str, file_uuid: str) -> None: """ Deletes a file via its uuid. @@ -364,7 +353,7 @@ class FileSystem(SimComponent): else: self.sys_log.error(f"Unable to delete file that does not exist. (id: {file_uuid})") - def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str): + def move_file(self, src_folder_name: str, src_file_name: str, dst_folder_name: str) -> None: """ Move a file from one folder to another. @@ -374,15 +363,14 @@ class FileSystem(SimComponent): """ file = self.get_file(folder_name=src_folder_name, file_name=src_file_name) if file: - src_folder = file.folder - # remove file from src - src_folder.remove_file(file) + self.delete_file(folder_name=file.folder_name, file_name=file.name) dst_folder = self.get_folder(folder_name=dst_folder_name) if not dst_folder: dst_folder = self.create_folder(dst_folder_name) # add file to dst dst_folder.add_file(file) + self.num_file_creations += 1 if file.real: old_sim_path = file.sim_path file.sim_path = file.sim_root / file.path @@ -410,6 +398,10 @@ class FileSystem(SimComponent): folder_name=dst_folder.name, **file.model_dump(exclude={"uuid", "folder_id", "folder_name", "sim_path"}), ) + self.num_file_creations += 1 + # increment access counter + file.num_access += 1 + dst_folder.add_file(file_copy, force=True) if file.real: @@ -427,12 +419,20 @@ class FileSystem(SimComponent): state = super().describe_state() state["folders"] = {folder.name: folder.describe_state() for folder in self.folders.values()} state["deleted_folders"] = {folder.name: folder.describe_state() for folder in self.deleted_folders.values()} + state["num_file_creations"] = self.num_file_creations + state["num_file_deletions"] = self.num_file_deletions return state def apply_timestep(self, timestep: int) -> None: """Apply time step to FileSystem and its child folders and files.""" super().apply_timestep(timestep=timestep) + # reset number of file creations + self.num_file_creations = 0 + + # reset number of file deletions + self.num_file_deletions = 0 + # apply timestep to folders for folder_id in self.folders: self.folders[folder_id].apply_timestep(timestep=timestep) @@ -441,7 +441,7 @@ class FileSystem(SimComponent): # Agent actions ############################################################### - def scan(self, instant_scan: bool = False): + def scan(self, instant_scan: bool = False) -> None: """ Scan all the folders (and child files) in the file system. @@ -450,7 +450,7 @@ class FileSystem(SimComponent): for folder_id in self.folders: self.folders[folder_id].scan(instant_scan=instant_scan) - def reveal_to_red(self, instant_scan: bool = False): + def reveal_to_red(self, instant_scan: bool = False) -> None: """ Reveals all the folders (and child files) in the file system to the red agent. @@ -459,7 +459,7 @@ class FileSystem(SimComponent): for folder_id in self.folders: self.folders[folder_id].reveal_to_red(instant_scan=instant_scan) - def restore_folder(self, folder_name: str): + def restore_folder(self, folder_name: str) -> bool: """ Restore a folder. @@ -472,13 +472,14 @@ class FileSystem(SimComponent): if folder is None: self.sys_log.error(f"Unable to restore folder {folder_name}. Folder is not in deleted folder list.") - return + return False self.deleted_folders.pop(folder.uuid, None) folder.restore() self.folders[folder.uuid] = folder + return True - def restore_file(self, folder_name: str, file_name: str): + def restore_file(self, folder_name: str, file_name: str) -> bool: """ Restore a file. @@ -491,12 +492,15 @@ class FileSystem(SimComponent): :type: file_name: str """ folder = self.get_folder(folder_name=folder_name) + if not folder: + _LOGGER.debug(f"Cannot restore file {file_name} in folder {folder_name} as the folder does not exist.") + return False - if folder: - file = folder.get_file(file_name=file_name, include_deleted=True) + file = folder.get_file(file_name=file_name, include_deleted=True) - if file is None: - self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.") - return + if not file: + msg = f"Unable to restore file {file_name}. File was not found." + self.sys_log.error(msg) + return False - folder.restore_file(file_name=file_name) + return folder.restore_file(file_name=file_name) diff --git a/src/primaite/simulator/file_system/file_system_item_abc.py b/src/primaite/simulator/file_system/file_system_item_abc.py index c3e1426b..32f5f6be 100644 --- a/src/primaite/simulator/file_system/file_system_item_abc.py +++ b/src/primaite/simulator/file_system/file_system_item_abc.py @@ -6,6 +6,7 @@ from enum import Enum from typing import Dict, Optional from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.system.core.sys_log import SysLog @@ -85,11 +86,6 @@ class FileSystemItemABC(SimComponent): deleted: bool = False "If true, the FileSystemItem was deleted." - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"name", "health_status", "visible_health_status", "previous_hash", "revealed_to_red", "deleted"} - self._original_state = self.model_dump(include=vals_to_keep) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -105,14 +101,33 @@ class FileSystemItemABC(SimComponent): return state def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() - rm.add_request(name="scan", request_type=RequestType(func=lambda request, context: self.scan())) - rm.add_request(name="checkhash", request_type=RequestType(func=lambda request, context: self.check_hash())) - rm.add_request(name="repair", request_type=RequestType(func=lambda request, context: self.repair())) - rm.add_request(name="restore", request_type=RequestType(func=lambda request, context: self.restore())) + rm.add_request( + name="scan", request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan())) + ) + rm.add_request( + name="checkhash", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.check_hash())), + ) + rm.add_request( + name="repair", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.repair())), + ) + rm.add_request( + name="restore", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.restore())), + ) - rm.add_request(name="corrupt", request_type=RequestType(func=lambda request, context: self.corrupt())) + rm.add_request( + name="corrupt", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.corrupt())), + ) return rm @@ -129,9 +144,9 @@ class FileSystemItemABC(SimComponent): return convert_size(self.size) @abstractmethod - def scan(self) -> None: + def scan(self) -> bool: """Scan the folder/file - updates the visible_health_status.""" - pass + return False @abstractmethod def reveal_to_red(self) -> None: @@ -139,7 +154,7 @@ class FileSystemItemABC(SimComponent): pass @abstractmethod - def check_hash(self) -> None: + def check_hash(self) -> bool: """ Checks the has of the file to detect any changes. @@ -147,30 +162,30 @@ class FileSystemItemABC(SimComponent): Return False if corruption is detected, otherwise True """ - pass + return False @abstractmethod - def repair(self) -> None: + def repair(self) -> bool: """ Repair the FileSystemItem. True if successfully repaired. False otherwise. """ - pass + return False @abstractmethod - def corrupt(self) -> None: + def corrupt(self) -> bool: """ Corrupt the FileSystemItem. True if successfully corrupted. False otherwise. """ - pass + return False @abstractmethod - def restore(self) -> None: + def restore(self) -> bool: """Restore the file/folder to the state before it got ruined.""" - pass + return False @abstractmethod def delete(self) -> None: diff --git a/src/primaite/simulator/file_system/folder.py b/src/primaite/simulator/file_system/folder.py index 13fdc597..6ebd8d14 100644 --- a/src/primaite/simulator/file_system/folder.py +++ b/src/primaite/simulator/file_system/folder.py @@ -5,6 +5,7 @@ from typing import Dict, Optional from prettytable import MARKDOWN, PrettyTable from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_system_item_abc import FileSystemItemABC, FileSystemItemHealthStatus @@ -49,54 +50,18 @@ class Folder(FileSystemItemABC): self.sys_log.info(f"Created file /{self.name} (id: {self.uuid})") - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting Folder ({self.name}) original state on node {self.sys_log.hostname}") - for file in self.files.values(): - file.set_original_state() - super().set_original_state() - vals_to_include = { - "scan_duration", - "scan_countdown", - "red_scan_duration", - "red_scan_countdown", - "restore_duration", - "restore_countdown", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - self._original_state["original_file_uuids"] = list(self.files.keys()) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting Folder ({self.name}) state on node {self.sys_log.hostname}") - # Move any 'original' file that have been deleted back to files - original_file_uuids = self._original_state["original_file_uuids"] - for uuid in original_file_uuids: - if uuid in self.deleted_files: - file = self.deleted_files[uuid] - self.deleted_files.pop(uuid) - self.files[uuid] = file - - # Clear any other deleted files that aren't original (have been created by agent) - self.deleted_files.clear() - - # Now clear all non-original files created by agent - current_file_uuids = list(self.files.keys()) - for uuid in current_file_uuids: - if uuid not in original_file_uuids: - file = self.files[uuid] - self.files.pop(uuid) - - # Now reset all remaining files - for file in self.files.values(): - file.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() rm.add_request( name="delete", - request_type=RequestType(func=lambda request, context: self.remove_file_by_id(file_uuid=request[0])), + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.remove_file_by_name(file_name=request[0])) + ), ) self._file_request_manager = RequestManager() rm.add_request( @@ -173,7 +138,8 @@ class Folder(FileSystemItemABC): file = self.get_file_by_id(file_uuid=file_id) file.scan() if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: - self.visible_health_status = FileSystemItemHealthStatus.CORRUPT + self.health_status = FileSystemItemHealthStatus.CORRUPT + self.visible_health_status = self.health_status def _reveal_to_red_timestep(self) -> None: """Apply reveal to red timestep.""" @@ -292,6 +258,21 @@ class Folder(FileSystemItemABC): file = self.get_file_by_id(file_uuid=file_uuid) self.remove_file(file=file) + def remove_file_by_name(self, file_name: str) -> bool: + """ + Remove a file using its name. + + :param file_name: filename + :type file_name: str + :return: Whether it was successfully removed. + :rtype: bool + """ + for f in self.files.values(): + if f.name == file_name: + self.remove_file(f) + return True + return False + def remove_all_files(self): """Removes all the files in the folder.""" for file_id in self.files: @@ -301,7 +282,7 @@ class Folder(FileSystemItemABC): self.files = {} - def restore_file(self, file_name: str): + def restore_file(self, file_name: str) -> bool: """ Restores a file. @@ -311,13 +292,14 @@ class Folder(FileSystemItemABC): file = self.get_file(file_name=file_name, include_deleted=True) if not file: self.sys_log.error(f"Unable to restore file {file_name}. File does not exist.") - return + return False file.restore() self.files[file.uuid] = file if file.deleted: self.deleted_files.pop(file.uuid) + return True def quarantine(self): """Quarantines the File System Folder.""" @@ -331,7 +313,7 @@ class Folder(FileSystemItemABC): """Returns true if the folder is being quarantined.""" pass - def scan(self, instant_scan: bool = False) -> None: + def scan(self, instant_scan: bool = False) -> bool: """ Update Folder visible status. @@ -339,7 +321,7 @@ class Folder(FileSystemItemABC): """ if self.deleted: self.sys_log.error(f"Unable to scan deleted folder {self.name}") - return + return False if instant_scan: for file_id in self.files: @@ -347,7 +329,7 @@ class Folder(FileSystemItemABC): file.scan() if file.visible_health_status == FileSystemItemHealthStatus.CORRUPT: self.visible_health_status = FileSystemItemHealthStatus.CORRUPT - return + return True if self.scan_countdown <= 0: # scan one file per timestep @@ -356,6 +338,7 @@ class Folder(FileSystemItemABC): else: # scan already in progress self.sys_log.info(f"Scan is already in progress {self.name} (id: {self.uuid})") + return True def reveal_to_red(self, instant_scan: bool = False): """ @@ -382,7 +365,7 @@ class Folder(FileSystemItemABC): # scan already in progress self.sys_log.info(f"Red Agent Scan is already in progress {self.name} (id: {self.uuid})") - def check_hash(self) -> None: + def check_hash(self) -> bool: """ Runs a :func:`check_hash` on all files in the folder. @@ -395,7 +378,7 @@ class Folder(FileSystemItemABC): """ if self.deleted: self.sys_log.error(f"Unable to check hash of deleted folder {self.name}") - return + return False # iterate through the files and run a check hash no_corrupted_files = True @@ -411,12 +394,13 @@ class Folder(FileSystemItemABC): self.corrupt() self.sys_log.info(f"Checking hash of folder {self.name} (id: {self.uuid})") + return True - def repair(self) -> None: + def repair(self) -> bool: """Repair a corrupted Folder by setting the folder and containing files status to FileSystemItemStatus.GOOD.""" if self.deleted: self.sys_log.error(f"Unable to repair deleted folder {self.name}") - return + return False # iterate through the files in the folder for file_id in self.files: @@ -430,8 +414,9 @@ class Folder(FileSystemItemABC): self.health_status = FileSystemItemHealthStatus.GOOD self.sys_log.info(f"Repaired folder {self.name} (id: {self.uuid})") + return True - def restore(self) -> None: + def restore(self) -> bool: """ If a Folder is corrupted, run a repair on the folder and its child files. @@ -447,12 +432,13 @@ class Folder(FileSystemItemABC): else: # scan already in progress self.sys_log.info(f"Folder restoration already in progress {self.name} (id: {self.uuid})") + return True - def corrupt(self) -> None: + def corrupt(self) -> bool: """Corrupt a File by setting the folder and containing files status to FileSystemItemStatus.CORRUPT.""" if self.deleted: self.sys_log.error(f"Unable to corrupt deleted folder {self.name}") - return + return False # iterate through the files in the folder for file_id in self.files: @@ -463,11 +449,13 @@ class Folder(FileSystemItemABC): self.health_status = FileSystemItemHealthStatus.CORRUPT self.sys_log.info(f"Corrupted folder {self.name} (id: {self.uuid})") + return True - def delete(self): + def delete(self) -> bool: """Marks the file as deleted. Prevents agent actions from occuring.""" if self.deleted: self.sys_log.error(f"Unable to delete an already deleted folder {self.name}") - return + return False self.deleted = True + return True diff --git a/src/primaite/simulator/network/airspace.py b/src/primaite/simulator/network/airspace.py index 724b8728..a8343675 100644 --- a/src/primaite/simulator/network/airspace.py +++ b/src/primaite/simulator/network/airspace.py @@ -157,7 +157,7 @@ class WirelessNetworkInterface(NetworkInterface, ABC): return if not self._connected_node: - _LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node") + _LOGGER.warning(f"Interface {self} cannot be enabled as it is not connected to a Node") return if self._connected_node.operating_state != NodeOperatingState.ON: @@ -168,7 +168,9 @@ class WirelessNetworkInterface(NetworkInterface, ABC): self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") - self.pcap = PacketCapture(hostname=self._connected_node.hostname, interface_num=self.port_num) + self.pcap = PacketCapture( + hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + ) AIR_SPACE.add_wireless_interface(self) def disable(self): @@ -273,11 +275,6 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC) return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def enable(self): """ Enables this wired network interface and attempts to send a "hello" message to the default gateway. diff --git a/src/primaite/simulator/network/container.py b/src/primaite/simulator/network/container.py index b32d2630..0e970c3d 100644 --- a/src/primaite/simulator/network/container.py +++ b/src/primaite/simulator/network/container.py @@ -8,10 +8,6 @@ from prettytable import MARKDOWN, PrettyTable from primaite import getLogger from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.hardware.base import Link, Node, WiredNetworkInterface -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.hardware.nodes.network.router import Router -from primaite.simulator.network.hardware.nodes.network.switch import Switch from primaite.simulator.system.applications.application import Application from primaite.simulator.system.services.service import Service @@ -45,19 +41,12 @@ class Network(SimComponent): self._nx_graph = MultiGraph() - def set_original_state(self): - """Sets the original state.""" - for node in self.nodes.values(): - node.set_original_state() - for link in self.links.values(): - link.set_original_state() - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" for node in self.nodes.values(): - node.reset_component_for_episode(episode) + node.setup_for_episode(episode=episode) for link in self.links.values(): - link.reset_component_for_episode(episode) + link.setup_for_episode(episode=episode) for node in self.nodes.values(): node.power_on() @@ -72,6 +61,11 @@ class Network(SimComponent): software.run() def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() self._node_request_manager = RequestManager() rm.add_request( @@ -92,24 +86,29 @@ class Network(SimComponent): self.links[link_id].apply_timestep(timestep=timestep) @property - def routers(self) -> List[Router]: + def router_nodes(self) -> List[Node]: """The Routers in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Router)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Router"] @property - def switches(self) -> List[Switch]: + def switch_nodes(self) -> List[Node]: """The Switches in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Switch)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Switch"] @property - def computers(self) -> List[Computer]: + def computer_nodes(self) -> List[Node]: """The Computers in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Computer) and not isinstance(node, Server)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Computer"] @property - def servers(self) -> List[Server]: + def server_nodes(self) -> List[Node]: """The Servers in the Network.""" - return [node for node in self.nodes.values() if isinstance(node, Server)] + return [node for node in self.nodes.values() if node.__class__.__name__ == "Server"] + + @property + def firewall_nodes(self) -> List[Node]: + """The Firewalls in the Network.""" + return [node for node in self.nodes.values() if node.__class__.__name__ == "Firewall"] def show(self, nodes: bool = True, ip_addresses: bool = True, links: bool = True, markdown: bool = False): """ @@ -124,10 +123,11 @@ class Network(SimComponent): :param markdown: Use Markdown style in table output. Defaults to False. """ nodes_type_map = { - "Router": self.routers, - "Switch": self.switches, - "Server": self.servers, - "Computer": self.computers, + "Router": self.router_nodes, + "Firewall": self.firewall_nodes, + "Switch": self.switch_nodes, + "Server": self.server_nodes, + "Computer": self.computer_nodes, } if nodes: table = PrettyTable(["Node", "Type", "Operating State"]) @@ -150,7 +150,10 @@ class Network(SimComponent): for node in nodes: for i, port in node.network_interface.items(): if hasattr(port, "ip_address"): - table.add_row([node.hostname, i, port.ip_address, port.subnet_mask, node.default_gateway]) + port_str = port.port_name if port.port_name else port.port_num + table.add_row( + [node.hostname, port_str, port.ip_address, port.subnet_mask, node.default_gateway] + ) print(table) if links: @@ -179,7 +182,7 @@ class Network(SimComponent): def clear_links(self): """Clear all the links in the network by resetting their component state for the episode.""" for link in self.links.values(): - link.reset_component_for_episode() + link.setup_for_episode(episode=0) # TODO: shouldn't be using this method here. def draw(self, seed: int = 123): """ diff --git a/src/primaite/simulator/network/hardware/base.py b/src/primaite/simulator/network/hardware/base.py index 7354725a..0cad4124 100644 --- a/src/primaite/simulator/network/hardware/base.py +++ b/src/primaite/simulator/network/hardware/base.py @@ -12,11 +12,21 @@ from pydantic import BaseModel, Field from primaite import getLogger from primaite.exceptions import NetworkError +from primaite.interface.request import RequestResponse from primaite.simulator import SIM_OUTPUT -from primaite.simulator.core import RequestManager, RequestType, SimComponent +from primaite.simulator.core import RequestFormat, RequestManager, RequestPermissionValidator, RequestType, SimComponent from primaite.simulator.domain.account import Account from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.nmne import ( + CAPTURE_BY_DIRECTION, + CAPTURE_BY_IP_ADDRESS, + CAPTURE_BY_KEYWORD, + CAPTURE_BY_PORT, + CAPTURE_BY_PROTOCOL, + CAPTURE_NMNE, + NMNE_CAPTURE_KEYWORDS, +) from primaite.simulator.network.transmission.data_link_layer import Frame from primaite.simulator.system.applications.application import Application from primaite.simulator.system.core.packet_capture import PacketCapture @@ -85,14 +95,34 @@ class NetworkInterface(SimComponent, ABC): port_num: Optional[int] = None "The port number assigned to this interface on the connected node." + port_name: Optional[str] = None + "The port name assigned to this interface on the connected node." + pcap: Optional[PacketCapture] = None "A PacketCapture instance for capturing and analysing packets passing through this interface." + nmne: Dict = Field(default_factory=lambda: {}) + "A dict containing details of the number of malicious network events captured." + + def setup_for_episode(self, episode: int): + """Reset the original state of the SimComponent.""" + super().setup_for_episode(episode=episode) + self.nmne = {} + if episode and self.pcap and SIM_OUTPUT.save_pcap_logs: + self.pcap.current_episode = episode + self.pcap.setup_logger() + self.enable() + def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() - rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) - rm.add_request("disable", RequestType(func=lambda request, context: self.disable())) + rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) + rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) return rm @@ -111,25 +141,97 @@ class NetworkInterface(SimComponent, ABC): "enabled": self.enabled, } ) + if CAPTURE_NMNE: + state.update({"nmne": {k: v for k, v in self.nmne.items()}}) return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) - if episode and self.pcap: - self.pcap.current_episode = episode - self.pcap.setup_logger() - self.enable() - @abstractmethod - def enable(self): + def enable(self) -> bool: """Enable the interface.""" pass + return False @abstractmethod - def disable(self): + def disable(self) -> bool: """Disable the interface.""" pass + return False + + def _capture_nmne(self, frame: Frame, inbound: bool = True) -> None: + """ + Processes and captures network frame data based on predefined global NMNE settings. + + This method updates the NMNE structure with counts of malicious network events based on the frame content and + direction. The structure is dynamically adjusted according to the enabled capture settings. + + .. note:: + While there is a lot of logic in this code that defines a multi-level hierarchical NMNE structure, + most of it is unused for now as a result of all `CAPTURE_BY_<>` variables in + ``primaite.simulator.network.nmne`` being hardcoded and set as final. Once they're 'released' and made + configurable, this function will be updated to properly explain the dynamic data structure. + + :param frame: The network frame to process, containing IP, TCP/UDP, and payload information. + :param inbound: Boolean indicating if the frame direction is inbound. Defaults to True. + """ + # Exit function if NMNE capturing is disabled + if not CAPTURE_NMNE: + return + + # Initialise basic frame data variables + direction = "inbound" if inbound else "outbound" # Direction of the traffic + ip_address = str(frame.ip.src_ip_address if inbound else frame.ip.dst_ip_address) # Source or destination IP + protocol = frame.ip.protocol.name # Network protocol used in the frame + + # Initialise port variable; will be determined based on protocol type + port = None + + # Determine the source or destination port based on the protocol (TCP/UDP) + if frame.tcp: + port = frame.tcp.src_port.value if inbound else frame.tcp.dst_port.value + elif frame.udp: + port = frame.udp.src_port.value if inbound else frame.udp.dst_port.value + + # Convert frame payload to string for keyword checking + frame_str = str(frame.payload) + + # Proceed only if any NMNE keyword is present in the frame payload + if any(keyword in frame_str for keyword in NMNE_CAPTURE_KEYWORDS): + # Start with the root of the NMNE capture structure + current_level = self.nmne + + # Update NMNE structure based on enabled settings + if CAPTURE_BY_DIRECTION: + # Set or get the dictionary for the current direction + current_level = current_level.setdefault("direction", {}) + current_level = current_level.setdefault(direction, {}) + + if CAPTURE_BY_IP_ADDRESS: + # Set or get the dictionary for the current IP address + current_level = current_level.setdefault("ip_address", {}) + current_level = current_level.setdefault(ip_address, {}) + + if CAPTURE_BY_PROTOCOL: + # Set or get the dictionary for the current protocol + current_level = current_level.setdefault("protocol", {}) + current_level = current_level.setdefault(protocol, {}) + + if CAPTURE_BY_PORT: + # Set or get the dictionary for the current port + current_level = current_level.setdefault("port", {}) + current_level = current_level.setdefault(port, {}) + + # Ensure 'KEYWORD' level is present in the structure + keyword_level = current_level.setdefault("keywords", {}) + + # Increment the count for detected keywords in the payload + if CAPTURE_BY_KEYWORD: + for keyword in NMNE_CAPTURE_KEYWORDS: + if keyword in frame_str: + # Update the count for each keyword found + keyword_level[keyword] = keyword_level.get(keyword, 0) + 1 + else: + # Increment a generic counter if keyword capturing is not enabled + keyword_level["*"] = keyword_level.get("*", 0) + 1 @abstractmethod def send_frame(self, frame: Frame) -> bool: @@ -139,7 +241,7 @@ class NetworkInterface(SimComponent, ABC): :param frame: The network frame to be sent. :return: A boolean indicating whether the frame was successfully sent. """ - pass + self._capture_nmne(frame, inbound=False) @abstractmethod def receive_frame(self, frame: Frame) -> bool: @@ -149,7 +251,7 @@ class NetworkInterface(SimComponent, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + self._capture_nmne(frame, inbound=True) def __str__(self) -> str: """ @@ -157,7 +259,15 @@ class NetworkInterface(SimComponent, ABC): :return: A string combining the port number and the mac address """ - return f"Port {self.port_num}: {self.mac_address}" + return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}" + + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep evolution to this component. + + This just clears the nmne count back to 0. + """ + super().apply_timestep(timestep=timestep) class WiredNetworkInterface(NetworkInterface, ABC): @@ -181,35 +291,38 @@ class WiredNetworkInterface(NetworkInterface, ABC): _connected_link: Optional[Link] = None "The network link to which the network interface is connected." - def enable(self): + def enable(self) -> bool: """Attempt to enable the network interface.""" if self.enabled: - return + return True if not self._connected_node: - _LOGGER.error(f"Interface {self} cannot be enabled as it is not connected to a Node") - return + _LOGGER.warning(f"Interface {self} cannot be enabled as it is not connected to a Node") + return False if self._connected_node.operating_state != NodeOperatingState.ON: self._connected_node.sys_log.info( f"Interface {self} cannot be enabled as the connected Node is not powered on" ) - return + return False if not self._connected_link: self._connected_node.sys_log.info(f"Interface {self} cannot be enabled as there is no Link connected.") - return + return False self.enabled = True self._connected_node.sys_log.info(f"Network Interface {self} enabled") - self.pcap = PacketCapture(hostname=self._connected_node.hostname, interface_num=self.port_num) + self.pcap = PacketCapture( + hostname=self._connected_node.hostname, port_num=self.port_num, port_name=self.port_name + ) if self._connected_link: self._connected_link.endpoint_up() + return True - def disable(self): + def disable(self) -> bool: """Disable the network interface.""" if not self.enabled: - return + return True self.enabled = False if self._connected_node: self._connected_node.sys_log.info(f"Network Interface {self} disabled") @@ -217,6 +330,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): _LOGGER.debug(f"Interface {self} disabled") if self._connected_link: self._connected_link.endpoint_down() + return True def connect_link(self, link: Link): """ @@ -229,11 +343,11 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param link: The Link instance to connect to this network interface. """ if self._connected_link: - _LOGGER.error(f"Cannot connect Link to network interface {self} as it already has a connection") + _LOGGER.warning(f"Cannot connect Link to network interface {self} as it already has a connection") return if self._connected_link == link: - _LOGGER.error(f"Cannot connect Link to network interface {self} as it is already connected") + _LOGGER.warning(f"Cannot connect Link to network interface {self} as it is already connected") return self._connected_link = link @@ -263,6 +377,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param frame: The network frame to be sent. :return: True if the frame is sent, False if the Network Interface is disabled or not connected to a link. """ + super().send_frame(frame) if self.enabled: frame.set_sent_timestamp() self.pcap.capture_outbound(frame) @@ -279,7 +394,7 @@ class WiredNetworkInterface(NetworkInterface, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + return super().receive_frame(frame) class Layer3Interface(BaseModel, ABC): @@ -390,7 +505,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): return state - def enable(self): + def enable(self) -> bool: """ Enables this wired network interface and attempts to send a "hello" message to the default gateway. @@ -406,10 +521,12 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): try: pass self._connected_node.default_gateway_hello() + return True except AttributeError: pass + return False - # @abstractmethod + @abstractmethod def receive_frame(self, frame: Frame) -> bool: """ Receives a network frame on the network interface. @@ -417,7 +534,7 @@ class IPWiredNetworkInterface(WiredNetworkInterface, Layer3Interface, ABC): :param frame: The network frame being received. :return: A boolean indicating whether the frame was successfully received. """ - pass + return super().receive_frame(frame) class Link(SimComponent): @@ -455,14 +572,6 @@ class Link(SimComponent): self.endpoint_b.connect_link(self) self.endpoint_up() - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"bandwidth", "current_load"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -548,6 +657,11 @@ class Link(SimComponent): def __str__(self) -> str: return f"{self.endpoint_a}<-->{self.endpoint_b}" + def apply_timestep(self, timestep: int) -> None: + """Apply a timestep to the simulation.""" + super().apply_timestep(timestep) + self.current_load = 0.0 + class Node(SimComponent): """ @@ -643,84 +757,93 @@ class Node(SimComponent): self.session_manager.node = self self.session_manager.software_manager = self.software_manager self._install_system_software() - self.set_original_state() - def set_original_state(self): - """Sets the original state.""" - for software in self.software_manager.software.values(): - software.set_original_state() - - self.file_system.set_original_state() - - for network_interface in self.network_interfaces.values(): - network_interface.set_original_state() - - vals_to_include = { - "hostname", - "default_gateway", - "operating_state", - "revealed_to_red", - "start_up_duration", - "start_up_countdown", - "shut_down_duration", - "shut_down_countdown", - "is_resetting", - "node_scan_duration", - "node_scan_countdown", - "red_scan_countdown", - } - self._original_state = self.model_dump(include=vals_to_include) - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - super().reset_component_for_episode(episode) - - # Reset Session Manager - self.session_manager.clear() + super().setup_for_episode(episode=episode) # Reset File System - self.file_system.reset_component_for_episode(episode) + self.file_system.setup_for_episode(episode=episode) # Reset all Nics for network_interface in self.network_interfaces.values(): - network_interface.reset_component_for_episode(episode) + network_interface.setup_for_episode(episode=episode) for software in self.software_manager.software.values(): - software.reset_component_for_episode(episode) + software.setup_for_episode(episode=episode) if episode and self.sys_log: self.sys_log.current_episode = episode self.sys_log.setup_logger() + class _NodeIsOnValidator(RequestPermissionValidator): + """ + When requests come in, this validator will only let them through if the node is on. + + This is useful because no actions should be being resolved if the node is off. + """ + + node: Node + """Save a reference to the node instance.""" + + def __call__(self, request: RequestFormat, context: Dict) -> bool: + """Return whether the node is on or off.""" + return self.node.operating_state == NodeOperatingState.ON + def _init_request_manager(self) -> RequestManager: - # TODO: I see that this code is really confusing and hard to read right now... I think some of these things will - # need a better name and better documentation. + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ + _node_is_on = Node._NodeIsOnValidator(node=self) + rm = super()._init_request_manager() # since there are potentially many services, create an request manager that can map service name self._service_request_manager = RequestManager() - rm.add_request("service", RequestType(func=self._service_request_manager)) + rm.add_request("service", RequestType(func=self._service_request_manager, validator=_node_is_on)) self._nic_request_manager = RequestManager() - rm.add_request("network_interface", RequestType(func=self._nic_request_manager)) + rm.add_request("network_interface", RequestType(func=self._nic_request_manager, validator=_node_is_on)) - rm.add_request("file_system", RequestType(func=self.file_system._request_manager)) + rm.add_request("file_system", RequestType(func=self.file_system._request_manager, validator=_node_is_on)) # currently we don't have any applications nor processes, so these will be empty self._process_request_manager = RequestManager() - rm.add_request("process", RequestType(func=self._process_request_manager)) + rm.add_request("process", RequestType(func=self._process_request_manager, validator=_node_is_on)) self._application_request_manager = RequestManager() - rm.add_request("application", RequestType(func=self._application_request_manager)) + rm.add_request("application", RequestType(func=self._application_request_manager, validator=_node_is_on)) - rm.add_request("scan", RequestType(func=lambda request, context: self.reveal_to_red())) + rm.add_request( + "scan", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.reveal_to_red()), validator=_node_is_on + ), + ) - rm.add_request("shutdown", RequestType(func=lambda request, context: self.power_off())) - rm.add_request("startup", RequestType(func=lambda request, context: self.power_on())) - rm.add_request("reset", RequestType(func=lambda request, context: self.reset())) # TODO implement node reset - rm.add_request("logon", RequestType(func=lambda request, context: ...)) # TODO implement logon request - rm.add_request("logoff", RequestType(func=lambda request, context: ...)) # TODO implement logoff request + rm.add_request( + "shutdown", + RequestType( + func=lambda request, context: RequestResponse.from_bool(self.power_off()), validator=_node_is_on + ), + ) + rm.add_request("startup", RequestType(func=lambda request, context: RequestResponse.from_bool(self.power_on()))) + rm.add_request( + "reset", + RequestType(func=lambda request, context: RequestResponse.from_bool(self.reset()), validator=_node_is_on), + ) # TODO implement node reset + rm.add_request( + "logon", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) + ) # TODO implement logon request + rm.add_request( + "logoff", RequestType(func=lambda request, context: RequestResponse.from_bool(False), validator=_node_is_on) + ) # TODO implement logoff request self._os_request_manager = RequestManager() - self._os_request_manager.add_request("scan", RequestType(func=lambda request, context: self.scan())) - rm.add_request("os", RequestType(func=self._os_request_manager)) + self._os_request_manager.add_request( + "scan", + RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()), validator=_node_is_on), + ) + rm.add_request("os", RequestType(func=self._os_request_manager, validator=_node_is_on)) return rm @@ -819,6 +942,9 @@ class Node(SimComponent): """ super().apply_timestep(timestep=timestep) + for network_interface in self.network_interfaces.values(): + network_interface.apply_timestep(timestep=timestep) + # count down to boot up if self.start_up_countdown > 0: self.start_up_countdown -= 1 @@ -897,7 +1023,7 @@ class Node(SimComponent): self.file_system.apply_timestep(timestep=timestep) - def scan(self) -> None: + def scan(self) -> bool: """ Scan the node and all the items within it. @@ -911,8 +1037,9 @@ class Node(SimComponent): to the red agent. """ self.node_scan_countdown = self.node_scan_duration + return True - def reveal_to_red(self) -> None: + def reveal_to_red(self) -> bool: """ Reveals the node and all the items within it to the red agent. @@ -926,34 +1053,40 @@ class Node(SimComponent): `revealed_to_red` to `True`. """ self.red_scan_countdown = self.node_scan_duration + return True - def power_on(self): + def power_on(self) -> bool: """Power on the Node, enabling its NICs if it is in the OFF state.""" - if self.operating_state == NodeOperatingState.OFF: - self.operating_state = NodeOperatingState.BOOTING - self.start_up_countdown = self.start_up_duration - if self.start_up_duration <= 0: self.operating_state = NodeOperatingState.ON self._start_up_actions() self.sys_log.info("Power on") for network_interface in self.network_interfaces.values(): network_interface.enable() + return True + if self.operating_state == NodeOperatingState.OFF: + self.operating_state = NodeOperatingState.BOOTING + self.start_up_countdown = self.start_up_duration + return True - def power_off(self): + return False + + def power_off(self) -> bool: """Power off the Node, disabling its NICs if it is in the ON state.""" + if self.shut_down_duration <= 0: + self._shut_down_actions() + self.operating_state = NodeOperatingState.OFF + self.sys_log.info("Power off") + return True if self.operating_state == NodeOperatingState.ON: for network_interface in self.network_interfaces.values(): network_interface.disable() self.operating_state = NodeOperatingState.SHUTTING_DOWN self.shut_down_countdown = self.shut_down_duration + return True + return False - if self.shut_down_duration <= 0: - self._shut_down_actions() - self.operating_state = NodeOperatingState.OFF - self.sys_log.info("Power off") - - def reset(self): + def reset(self) -> bool: """ Resets the node. @@ -964,8 +1097,10 @@ class Node(SimComponent): self.is_resetting = True self.sys_log.info("Resetting") self.power_off() + return True + return False - def connect_nic(self, network_interface: NetworkInterface): + def connect_nic(self, network_interface: NetworkInterface, port_name: Optional[str] = None): """ Connect a Network Interface to the node. @@ -977,7 +1112,9 @@ class Node(SimComponent): new_nic_num = len(self.network_interfaces) self.network_interface[new_nic_num] = network_interface network_interface._connected_node = self - network_interface._port_num_on_node = new_nic_num + network_interface.port_num = new_nic_num + if port_name: + network_interface.port_name = port_name network_interface.parent = self self.sys_log.info(f"Connected Network Interface {network_interface}") if self.operating_state == NodeOperatingState.ON: diff --git a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py index bc24270e..4b73b6a8 100644 --- a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py +++ b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_access_point.py @@ -51,13 +51,15 @@ class WirelessAccessPoint(IPWirelessNetworkInterface): return state - def enable(self): + def enable(self) -> bool: """Enable the interface.""" pass + return True - def disable(self): + def disable(self) -> bool: """Disable the interface.""" pass + return True def send_frame(self, frame: Frame) -> bool: """ @@ -83,4 +85,4 @@ class WirelessAccessPoint(IPWirelessNetworkInterface): :return: A string combining the port number, MAC address and IP address of the NIC. """ - return f"Port {self.port_num}: {self.mac_address}/{self.ip_address}" + return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}" diff --git a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py index 32acc08a..2e0a1823 100644 --- a/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py +++ b/src/primaite/simulator/network/hardware/network_interface/wireless/wireless_nic.py @@ -48,13 +48,15 @@ class WirelessNIC(IPWirelessNetworkInterface): return state - def enable(self): + def enable(self) -> bool: """Enable the interface.""" pass + return True - def disable(self): + def disable(self) -> bool: """Disable the interface.""" pass + return True def send_frame(self, frame: Frame) -> bool: """ @@ -80,4 +82,4 @@ class WirelessNIC(IPWirelessNetworkInterface): :return: A string combining the port number, MAC address and IP address of the NIC. """ - return f"Port {self.port_num}: {self.mac_address}/{self.ip_address}" + return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}" diff --git a/src/primaite/simulator/network/hardware/nodes/host/host_node.py b/src/primaite/simulator/network/hardware/nodes/host/host_node.py index 3f34f736..31378689 100644 --- a/src/primaite/simulator/network/hardware/nodes/host/host_node.py +++ b/src/primaite/simulator/network/hardware/nodes/host/host_node.py @@ -1,7 +1,7 @@ from __future__ import annotations from ipaddress import IPv4Address -from typing import Any, Dict, Optional +from typing import Any, ClassVar, Dict, Optional from primaite import getLogger from primaite.simulator.network.hardware.base import IPWiredNetworkInterface, Link, Node @@ -205,19 +205,10 @@ class NIC(IPWiredNetworkInterface): state = super().describe_state() # Update the state with NIC-specific information - state.update( - { - "wake_on_lan": self.wake_on_lan, - } - ) + state.update({"wake_on_lan": self.wake_on_lan}) return state - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"ip_address", "subnet_mask", "mac_address", "speed", "mtu", "wake_on_lan", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - def receive_frame(self, frame: Frame) -> bool: """ Attempt to receive and process a network frame from the connected Link. @@ -248,6 +239,7 @@ class NIC(IPWiredNetworkInterface): accept_frame = True if accept_frame: + super().receive_frame(frame) self._connected_node.receive_frame(frame=frame, from_network_interface=self) return True return False @@ -258,7 +250,7 @@ class NIC(IPWiredNetworkInterface): :return: A string combining the port number, MAC address and IP address of the NIC. """ - return f"Port {self.port_num}: {self.mac_address}/{self.ip_address}" + return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}" class HostNode(Node): @@ -305,6 +297,16 @@ class HostNode(Node): * Web Browser: Provides web browsing capabilities. """ + SYSTEM_SOFTWARE: ClassVar[Dict] = { + "HostARP": HostARP, + "ICMP": ICMP, + "DNSClient": DNSClient, + "FTPClient": FTPClient, + "NTPClient": NTPClient, + "WebBrowser": WebBrowser, + } + """List of system software that is automatically installed on nodes.""" + network_interfaces: Dict[str, NIC] = {} "The Network Interfaces on the node." network_interface: Dict[int, NIC] = {} @@ -321,23 +323,8 @@ class HostNode(Node): This method equips the host with essential network services and applications, preparing it for various network-related tasks and operations. """ - # ARP Service - self.software_manager.install(HostARP) - - # ICMP Service - self.software_manager.install(ICMP) - - # DNS Client - self.software_manager.install(DNSClient) - - # FTP Client - self.software_manager.install(FTPClient) - - # NTP Client - self.software_manager.install(NTPClient) - - # Web Browser - self.software_manager.install(WebBrowser) + for _, software_class in self.SYSTEM_SOFTWARE.items(): + self.software_manager.install(software_class) super()._install_system_software() diff --git a/src/primaite/simulator/network/hardware/nodes/network/firewall.py b/src/primaite/simulator/network/hardware/nodes/network/firewall.py index 22effa2a..d7b1dfd9 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/firewall.py +++ b/src/primaite/simulator/network/hardware/nodes/network/firewall.py @@ -1,8 +1,10 @@ +from ipaddress import IPv4Address from typing import Dict, Final, Optional, Union from prettytable import MARKDOWN, PrettyTable from pydantic import validate_call +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState from primaite.simulator.network.hardware.nodes.network.router import ( AccessControlList, ACLAction, @@ -10,6 +12,8 @@ from primaite.simulator.network.hardware.nodes.network.router import ( RouterInterface, ) from primaite.simulator.network.transmission.data_link_layer import Frame +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog from primaite.utils.validators import IPV4Address @@ -85,7 +89,17 @@ class Firewall(Router): if not kwargs.get("sys_log"): kwargs["sys_log"] = SysLog(hostname) - super().__init__(hostname=hostname, num_ports=3, **kwargs) + super().__init__(hostname=hostname, num_ports=0, **kwargs) + + self.connect_nic( + RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="external") + ) + self.connect_nic( + RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="internal") + ) + self.connect_nic( + RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0", port_name="dmz") + ) # Initialise ACLs for internal and dmz interfaces with a default DENY policy self.internal_inbound_acl = AccessControlList( @@ -109,24 +123,6 @@ class Firewall(Router): sys_log=kwargs["sys_log"], implicit_action=ACLAction.PERMIT, name=f"{hostname} - External Outbound" ) - self.set_original_state() - - def set_original_state(self): - """Set the original state for the Firewall.""" - super().set_original_state() - vals_to_include = { - "internal_port", - "external_port", - "dmz_port", - "internal_inbound_acl", - "internal_outbound_acl", - "dmz_inbound_acl", - "dmz_outbound_acl", - "external_inbound_acl", - "external_outbound_acl", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - def describe_state(self) -> Dict: """ Describes the current state of the Firewall. @@ -491,3 +487,124 @@ class Firewall(Router): """ self.configure_port(DMZ_PORT_ID, ip_address, subnet_mask) self.dmz_port.enable() + + @classmethod + def from_config(cls, cfg: dict) -> "Firewall": + """Create a firewall based on a config dict.""" + firewall = Firewall( + hostname=cfg["hostname"], + operating_state=NodeOperatingState.ON + if not (p := cfg.get("operating_state")) + else NodeOperatingState[p.upper()], + ) + if "ports" in cfg: + internal_port = cfg["ports"]["internal_port"] + external_port = cfg["ports"]["external_port"] + dmz_port = cfg["ports"].get("dmz_port") + + # configure internal port + firewall.configure_internal_port( + ip_address=IPV4Address(internal_port.get("ip_address")), + subnet_mask=IPV4Address(internal_port.get("subnet_mask", "255.255.255.0")), + ) + + # configure external port + firewall.configure_external_port( + ip_address=IPV4Address(external_port.get("ip_address")), + subnet_mask=IPV4Address(external_port.get("subnet_mask", "255.255.255.0")), + ) + + # configure dmz port if not none + if dmz_port is not None: + firewall.configure_dmz_port( + ip_address=IPV4Address(dmz_port.get("ip_address")), + subnet_mask=IPV4Address(dmz_port.get("subnet_mask", "255.255.255.0")), + ) + if "acl" in cfg: + # acl rules for internal_inbound_acl + if cfg["acl"]["internal_inbound_acl"]: + for r_num, r_cfg in cfg["acl"]["internal_inbound_acl"].items(): + firewall.internal_inbound_acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + + # acl rules for internal_outbound_acl + if cfg["acl"]["internal_outbound_acl"]: + for r_num, r_cfg in cfg["acl"]["internal_outbound_acl"].items(): + firewall.internal_outbound_acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + + # acl rules for dmz_inbound_acl + if cfg["acl"]["dmz_inbound_acl"]: + for r_num, r_cfg in cfg["acl"]["dmz_inbound_acl"].items(): + firewall.dmz_inbound_acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + + # acl rules for dmz_outbound_acl + if cfg["acl"]["dmz_outbound_acl"]: + for r_num, r_cfg in cfg["acl"]["dmz_outbound_acl"].items(): + firewall.dmz_outbound_acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + + # acl rules for external_inbound_acl + if cfg["acl"].get("external_inbound_acl"): + for r_num, r_cfg in cfg["acl"]["external_inbound_acl"].items(): + firewall.external_inbound_acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + + # acl rules for external_outbound_acl + if cfg["acl"].get("external_outbound_acl"): + for r_num, r_cfg in cfg["acl"]["external_outbound_acl"].items(): + firewall.external_outbound_acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + + if "routes" in cfg: + for route in cfg.get("routes"): + firewall.route_table.add_route( + address=IPv4Address(route.get("address")), + subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), + next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), + metric=float(route.get("metric", 0)), + ) + return firewall diff --git a/src/primaite/simulator/network/hardware/nodes/network/router.py b/src/primaite/simulator/network/hardware/nodes/network/router.py index 774aae7c..d2b47c1a 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/router.py @@ -8,6 +8,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union from prettytable import MARKDOWN, PrettyTable from pydantic import validate_call +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.network.hardware.base import IPWiredNetworkInterface from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -136,11 +137,6 @@ class ACLRule(SimComponent): rule_strings.append(f"{key}={value}") return ", ".join(rule_strings) - def set_original_state(self): - """Sets the original state.""" - vals_to_keep = {"action", "protocol", "src_ip_address", "src_port", "dst_ip_address", "dst_port"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - def describe_state(self) -> Dict: """ Describes the current state of the ACLRule. @@ -296,50 +292,13 @@ class AccessControlList(SimComponent): super().__init__(**kwargs) self._acl = [None] * (self.max_acl_rules - 1) - self.set_original_state() - - def set_original_state(self): - """Sets the original state.""" - self.implicit_rule.set_original_state() - vals_to_keep = {"implicit_action", "max_acl_rules", "acl"} - self._original_state = self.model_dump(include=vals_to_keep, exclude_none=True) - - for i, rule in enumerate(self._acl): - if not rule: - continue - self._default_config[i] = {"action": rule.action.name} - if rule.src_ip_address: - self._default_config[i]["src_ip"] = str(rule.src_ip_address) - if rule.dst_ip_address: - self._default_config[i]["dst_ip"] = str(rule.dst_ip_address) - if rule.src_port: - self._default_config[i]["src_port"] = rule.src_port.name - if rule.dst_port: - self._default_config[i]["dst_port"] = rule.dst_port.name - if rule.protocol: - self._default_config[i]["protocol"] = rule.protocol.name - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.implicit_rule.reset_component_for_episode(episode) - super().reset_component_for_episode(episode) - self._reset_rules_to_default() - - def _reset_rules_to_default(self) -> None: - """Clear all ACL rules and set them to the default rules config.""" - self._acl = [None] * (self.max_acl_rules - 1) - for r_num, r_cfg in self._default_config.items(): - self.add_rule( - action=ACLAction[r_cfg["action"]], - src_port=None if not (p := r_cfg.get("src_port")) else Port[p], - dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], - protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], - src_ip_address=r_cfg.get("src_ip"), - dst_ip_address=r_cfg.get("dst_ip"), - position=r_num, - ) def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ # TODO: Add src and dst wildcard masks as positional args in this request. rm = super()._init_request_manager() @@ -355,19 +314,24 @@ class AccessControlList(SimComponent): rm.add_request( "add_rule", RequestType( - func=lambda request, context: self.add_rule( - action=ACLAction[request[0]], - protocol=None if request[1] == "ALL" else IPProtocol[request[1]], - src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), - src_port=None if request[3] == "ALL" else Port[request[3]], - dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]), - dst_port=None if request[5] == "ALL" else Port[request[5]], - position=int(request[6]), + func=lambda request, context: RequestResponse.from_bool( + self.add_rule( + action=ACLAction[request[0]], + protocol=None if request[1] == "ALL" else IPProtocol[request[1]], + src_ip_address=None if request[2] == "ALL" else IPv4Address(request[2]), + src_port=None if request[3] == "ALL" else Port[request[3]], + dst_ip_address=None if request[4] == "ALL" else IPv4Address(request[4]), + dst_port=None if request[5] == "ALL" else Port[request[5]], + position=int(request[6]), + ) ) ), ) - rm.add_request("remove_rule", RequestType(func=lambda request, context: self.remove_rule(int(request[0])))) + rm.add_request( + "remove_rule", + RequestType(func=lambda request, context: RequestResponse.from_bool(self.remove_rule(int(request[0])))), + ) return rm def describe_state(self) -> Dict: @@ -413,7 +377,7 @@ class AccessControlList(SimComponent): src_port: Optional[Port] = None, dst_port: Optional[Port] = None, position: int = 0, - ) -> None: + ) -> bool: """ Adds a new ACL rule to control network traffic based on specified criteria. @@ -470,10 +434,12 @@ class AccessControlList(SimComponent): src_port=src_port, dst_port=dst_port, ) + return True else: raise ValueError(f"Cannot add ACL rule, position {position} is out of bounds.") + return False - def remove_rule(self, position: int) -> None: + def remove_rule(self, position: int) -> bool: """ Remove an ACL rule from a specific position. @@ -484,8 +450,10 @@ class AccessControlList(SimComponent): rule = self._acl[position] # noqa self._acl[position] = None del rule + return True else: raise ValueError(f"Cannot remove ACL rule, position {position} is out of bounds.") + return False def is_permitted(self, frame: Frame) -> Tuple[bool, ACLRule]: """Check if a packet with the given properties is permitted through the ACL.""" @@ -616,11 +584,6 @@ class RouteEntry(SimComponent): metric: float = 0.0 "The cost metric for this route. Default is 0.0." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"address", "subnet_mask", "next_hop_ip_address", "metric"} - self._original_values = self.model_dump(include=vals_to_include) - def describe_state(self) -> Dict: """ Describes the current state of the RouteEntry. @@ -653,17 +616,6 @@ class RouteTable(SimComponent): default_route: Optional[RouteEntry] = None sys_log: SysLog - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - self._original_state["routes_orig"] = self.routes - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.routes.clear() - self.routes = self._original_state["routes_orig"] - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the RouteTable. @@ -1061,7 +1013,7 @@ class RouterInterface(IPWiredNetworkInterface): :return: A string combining the port number, MAC address and IP address of the NIC. """ - return f"Port {self.port_num}: {self.mac_address}/{self.ip_address}" + return f"Port {self.port_name if self.port_name else self.port_num}: {self.mac_address}/{self.ip_address}" class Router(NetworkNode): @@ -1104,8 +1056,6 @@ class Router(NetworkNode): self._set_default_acl() - self.set_original_state() - def _install_system_software(self): """ Installs essential system software and network services on the router. @@ -1131,20 +1081,7 @@ class Router(NetworkNode): self.acl.add_rule(action=ACLAction.PERMIT, src_port=Port.ARP, dst_port=Port.ARP, position=22) self.acl.add_rule(action=ACLAction.PERMIT, protocol=IPProtocol.ICMP, position=23) - def set_original_state(self): - """ - Sets or resets the router to its original configuration state. - - This method is called to initialize the router's state or to revert it to a known good configuration during - network simulations or after configuration changes. - """ - self.acl.set_original_state() - self.route_table.set_original_state() - super().set_original_state() - vals_to_include = {"num_ports"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """ Resets the router's components for a new network simulation episode. @@ -1154,15 +1091,17 @@ class Router(NetworkNode): :param episode: The episode number for which the router is being reset. """ self.software_manager.arp.clear() - self.acl.reset_component_for_episode(episode) - self.route_table.reset_component_for_episode(episode) - for i, network_interface in self.network_interface.items(): - network_interface.reset_component_for_episode(episode) + for i, _ in self.network_interface.items(): self.enable_port(i) - super().reset_component_for_episode(episode) + super().setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() rm.add_request("acl", RequestType(func=self.acl._request_manager)) return rm @@ -1391,7 +1330,6 @@ class Router(NetworkNode): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured Network Interface {network_interface}") - self.set_original_state() def enable_port(self, port: int): """ @@ -1480,19 +1418,37 @@ class Router(NetworkNode): :return: Configured router. :rtype: Router """ - new = Router( + router = Router( hostname=cfg["hostname"], - num_ports=cfg.get("num_ports"), - operating_state=NodeOperatingState.ON, + num_ports=int(cfg.get("num_ports", "5")), + operating_state=NodeOperatingState.ON + if not (p := cfg.get("operating_state")) + else NodeOperatingState[p.upper()], ) if "ports" in cfg: for port_num, port_cfg in cfg["ports"].items(): - new.configure_port( + router.configure_port( port=port_num, ip_address=port_cfg["ip_address"], - subnet_mask=port_cfg["subnet_mask"], + subnet_mask=IPv4Address(port_cfg.get("subnet_mask", "255.255.255.0")), ) if "acl" in cfg: - new.acl._default_config = cfg["acl"] # save the config to allow resetting - new.acl._reset_rules_to_default() # read the config and apply rules - return new + for r_num, r_cfg in cfg["acl"].items(): + router.acl.add_rule( + action=ACLAction[r_cfg["action"]], + src_port=None if not (p := r_cfg.get("src_port")) else Port[p], + dst_port=None if not (p := r_cfg.get("dst_port")) else Port[p], + protocol=None if not (p := r_cfg.get("protocol")) else IPProtocol[p], + src_ip_address=r_cfg.get("src_ip"), + dst_ip_address=r_cfg.get("dst_ip"), + position=r_num, + ) + if "routes" in cfg: + for route in cfg.get("routes"): + router.route_table.add_route( + address=IPv4Address(route.get("address")), + subnet_mask=IPv4Address(route.get("subnet_mask", "255.255.255.0")), + next_hop_ip_address=IPv4Address(route.get("next_hop_ip_address")), + metric=float(route.get("metric", 0)), + ) + return router diff --git a/src/primaite/simulator/network/hardware/nodes/network/switch.py b/src/primaite/simulator/network/hardware/nodes/network/switch.py index 33e6ee9a..557ea287 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/switch.py +++ b/src/primaite/simulator/network/hardware/nodes/network/switch.py @@ -32,12 +32,6 @@ class SwitchPort(WiredNetworkInterface): _connected_node: Optional[Switch] = None "The Switch to which the SwitchPort is connected." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = {"port_num", "mac_address", "speed", "mtu", "enabled"} - self._original_state = self.model_dump(include=vals_to_include) - super().set_original_state() - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. diff --git a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py index dd0b58d3..3e8d715f 100644 --- a/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py +++ b/src/primaite/simulator/network/hardware/nodes/network/wireless_router.py @@ -80,7 +80,10 @@ class WirelessAccessPoint(IPWirelessNetworkInterface): :return: A string combining the port number, MAC address and IP address of the NIC. """ - return f"Port {self.port_num}: {self.mac_address}/{self.ip_address} ({self.frequency})" + return ( + f"Port {self.port_name if self.port_name else self.port_num}: " + f"{self.mac_address}/{self.ip_address} ({self.frequency})" + ) class WirelessRouter(Router): @@ -122,8 +125,6 @@ class WirelessRouter(Router): self.connect_nic(RouterInterface(ip_address="127.0.0.1", subnet_mask="255.0.0.0", gateway="0.0.0.0")) - self.set_original_state() - @property def wireless_access_point(self) -> WirelessAccessPoint: """ @@ -166,7 +167,6 @@ class WirelessRouter(Router): network_interface.ip_address = ip_address network_interface.subnet_mask = subnet_mask self.sys_log.info(f"Configured WAP {network_interface}") - self.set_original_state() self.wireless_access_point.frequency = frequency # Set operating frequency self.wireless_access_point.enable() # Re-enable the WAP with new settings diff --git a/src/primaite/simulator/network/networks.py b/src/primaite/simulator/network/networks.py index f82dee4a..c1eef224 100644 --- a/src/primaite/simulator/network/networks.py +++ b/src/primaite/simulator/network/networks.py @@ -146,6 +146,12 @@ def arcd_uc2_network() -> Network: ) client_1.power_on() network.connect(endpoint_b=client_1.network_interface[1], endpoint_a=switch_2.network_interface[1]) + client_1.software_manager.install(DatabaseClient) + db_client_1: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client_1.configure(server_ip_address=IPv4Address("192.168.1.14")) + db_client_1.run() + web_browser_1 = client_1.software_manager.software.get("WebBrowser") + web_browser_1.target_url = "http://arcd.com/users/" client_1.software_manager.install(DataManipulationBot) db_manipulation_bot: DataManipulationBot = client_1.software_manager.software.get("DataManipulationBot") db_manipulation_bot.configure( @@ -165,8 +171,12 @@ def arcd_uc2_network() -> Network: start_up_duration=0, ) client_2.power_on() - web_browser = client_2.software_manager.software.get("WebBrowser") - web_browser.target_url = "http://arcd.com/users/" + client_2.software_manager.install(DatabaseClient) + db_client_2 = client_2.software_manager.software.get("DatabaseClient") + db_client_2.configure(server_ip_address=IPv4Address("192.168.1.14")) + db_client_2.run() + web_browser_2 = client_2.software_manager.software.get("WebBrowser") + web_browser_2.target_url = "http://arcd.com/users/" network.connect(endpoint_b=client_2.network_interface[1], endpoint_a=switch_2.network_interface[2]) # Domain Controller @@ -194,67 +204,10 @@ def arcd_uc2_network() -> Network: database_server.power_on() network.connect(endpoint_b=database_server.network_interface[1], endpoint_a=switch_1.network_interface[3]) - ddl = """ - CREATE TABLE IF NOT EXISTS user ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name VARCHAR(50) NOT NULL, - email VARCHAR(50) NOT NULL, - age INT, - city VARCHAR(50), - occupation VARCHAR(50) - );""" - - user_insert_statements = [ - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('John Doe', 'johndoe@example.com', 32, 'New York', 'Engineer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Jane Smith', 'janesmith@example.com', 27, 'Los Angeles', 'Designer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Bob Johnson', 'bobjohnson@example.com', 45, 'Chicago', 'Manager');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Alice Lee', 'alicelee@example.com', 22, 'San Francisco', 'Student');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('David Kim', 'davidkim@example.com', 38, 'Houston', 'Consultant');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Emily Chen', 'emilychen@example.com', 29, 'Seattle', 'Software Developer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Frank Wang', 'frankwang@example.com', 55, 'New York', 'Entrepreneur');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Grace Park', 'gracepark@example.com', 31, 'Los Angeles', 'Marketing Specialist');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Henry Wu', 'henrywu@example.com', 40, 'Chicago', 'Accountant');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Isabella Kim', 'isabellakim@example.com', 26, 'San Francisco', 'Graphic Designer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Jake Lee', 'jakelee@example.com', 33, 'Houston', 'Sales Manager');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Kelly Chen', 'kellychen@example.com', 28, 'Seattle', 'Web Developer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Lucas Liu', 'lucasliu@example.com', 42, 'New York', 'Lawyer');", - # noqa - "INSERT INTO user (name, email, age, city, occupation) " - "VALUES ('Maggie Wang', 'maggiewang@example.com', 30, 'Los Angeles', 'Data Analyst');", - # noqa - ] database_server.software_manager.install(DatabaseService) database_service: DatabaseService = database_server.software_manager.software.get("DatabaseService") # noqa database_service.start() database_service.configure_backup(backup_server=IPv4Address("192.168.1.16")) - database_service._process_sql(ddl, None, None) # noqa - for insert_statement in user_insert_statements: - database_service._process_sql(insert_statement, None, None) # noqa # Web Server web_server = Server( diff --git a/src/primaite/simulator/network/nmne.py b/src/primaite/simulator/network/nmne.py new file mode 100644 index 00000000..87839712 --- /dev/null +++ b/src/primaite/simulator/network/nmne.py @@ -0,0 +1,47 @@ +from typing import Dict, Final, List + +CAPTURE_NMNE: bool = True +"""Indicates whether Malicious Network Events (MNEs) should be captured. Default is True.""" + +NMNE_CAPTURE_KEYWORDS: List[str] = [] +"""List of keywords to identify malicious network events.""" + +# TODO: Remove final and make configurable after example layout when the NicObservation creates nmne structure dynamically +CAPTURE_BY_DIRECTION: Final[bool] = True +"""Flag to determine if captures should be organized by traffic direction (inbound/outbound).""" +CAPTURE_BY_IP_ADDRESS: Final[bool] = False +"""Flag to determine if captures should be organized by source or destination IP address.""" +CAPTURE_BY_PROTOCOL: Final[bool] = False +"""Flag to determine if captures should be organized by network protocol (e.g., TCP, UDP).""" +CAPTURE_BY_PORT: Final[bool] = False +"""Flag to determine if captures should be organized by source or destination port.""" +CAPTURE_BY_KEYWORD: Final[bool] = False +"""Flag to determine if captures should be filtered and categorised based on specific keywords.""" + + +def set_nmne_config(nmne_config: Dict): + """ + Sets the configuration for capturing Malicious Network Events (MNEs) based on a provided dictionary. + + This function updates global settings related to NMNE capture, including whether to capture NMNEs and what + keywords to use for identifying NMNEs. + + The function ensures that the settings are updated only if they are provided in the `nmne_config` dictionary, + and maintains type integrity by checking the types of the provided values. + + :param nmne_config: A dictionary containing the NMNE configuration settings. Possible keys include: + "capture_nmne" (bool) to indicate whether NMNEs should be captured, "nmne_capture_keywords" (list of strings) + to specify keywords for NMNE identification. + """ + global NMNE_CAPTURE_KEYWORDS + global CAPTURE_NMNE + + # Update the NMNE capture flag, defaulting to False if not specified or if the type is incorrect + CAPTURE_NMNE = nmne_config.get("capture_nmne", False) + if not isinstance(CAPTURE_NMNE, bool): + CAPTURE_NMNE = True # Revert to default True if the provided value is not a boolean + + # Update the NMNE capture keywords, appending new keywords if provided + NMNE_CAPTURE_KEYWORDS += nmne_config.get("nmne_capture_keywords", []) + if not isinstance(NMNE_CAPTURE_KEYWORDS, list): + NMNE_CAPTURE_KEYWORDS = [] # Reset to empty list if the provided value is not a list diff --git a/src/primaite/simulator/network/transmission/network_layer.py b/src/primaite/simulator/network/transmission/network_layer.py index bdf4babc..8ee0b4af 100644 --- a/src/primaite/simulator/network/transmission/network_layer.py +++ b/src/primaite/simulator/network/transmission/network_layer.py @@ -9,11 +9,20 @@ _LOGGER = getLogger(__name__) class IPProtocol(Enum): - """Enum representing transport layer protocols in IP header.""" + """ + Enum representing transport layer protocols in IP header. + .. _List of IPProtocols: + """ + + NONE = "none" + """Placeholder for a non-protocol.""" TCP = "tcp" + """Transmission Control Protocol.""" UDP = "udp" + """User Datagram Protocol.""" ICMP = "icmp" + """Internet Control Message Protocol.""" class Precedence(Enum): diff --git a/src/primaite/simulator/network/transmission/transport_layer.py b/src/primaite/simulator/network/transmission/transport_layer.py index 7c7509ab..c73e451a 100644 --- a/src/primaite/simulator/network/transmission/transport_layer.py +++ b/src/primaite/simulator/network/transmission/transport_layer.py @@ -5,7 +5,11 @@ from pydantic import BaseModel class Port(Enum): - """Enumeration of common known TCP/UDP ports used by protocols for operation of network applications.""" + """ + Enumeration of common known TCP/UDP ports used by protocols for operation of network applications. + + .. _List of Ports: + """ NONE = 0 "Place holder for a non-port." diff --git a/src/primaite/simulator/sim_container.py b/src/primaite/simulator/sim_container.py index 896861e6..997cc0be 100644 --- a/src/primaite/simulator/sim_container.py +++ b/src/primaite/simulator/sim_container.py @@ -1,5 +1,6 @@ from typing import Dict +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType, SimComponent from primaite.simulator.domain.controller import DomainController from primaite.simulator.network.container import Network @@ -21,21 +22,23 @@ class Simulation(SimComponent): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - self.network.set_original_state() - - def reset_component_for_episode(self, episode: int): + def setup_for_episode(self, episode: int): """Reset the original state of the SimComponent.""" - self.network.reset_component_for_episode(episode) + self.network.setup_for_episode(episode=episode) def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() # pass through network requests to the network objects rm.add_request("network", RequestType(func=self.network._request_manager)) # pass through domain requests to the domain object rm.add_request("domain", RequestType(func=self.domain._request_manager)) - rm.add_request("do_nothing", RequestType(func=lambda request, context: ())) + # if 'do_nothing' is requested, just return a success + rm.add_request("do_nothing", RequestType(func=lambda request, context: RequestResponse(status="success"))) return rm def describe_state(self) -> Dict: diff --git a/src/primaite/simulator/system/applications/application.py b/src/primaite/simulator/system/applications/application.py index 322ac808..b7422680 100644 --- a/src/primaite/simulator/system/applications/application.py +++ b/src/primaite/simulator/system/applications/application.py @@ -38,12 +38,6 @@ class Application(IOSoftware): def __init__(self, **kwargs): super().__init__(**kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "execution_control_status", "num_executions", "groups"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ @@ -65,6 +59,16 @@ class Application(IOSoftware): ) return state + def apply_timestep(self, timestep: int) -> None: + """ + Apply a timestep to the application. + + :param timestep: The current timestep of the simulation. + """ + super().apply_timestep(timestep=timestep) + + self.num_executions = 0 # reset number of executions + def _can_perform_action(self) -> bool: """ Checks if the application can perform actions. @@ -79,7 +83,7 @@ class Application(IOSoftware): if self.operating_state is not self.operating_state.RUNNING: # service is not running - _LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}") + _LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}") return False return True diff --git a/src/primaite/simulator/system/applications/database_client.py b/src/primaite/simulator/system/applications/database_client.py index 5805ed43..d3afef59 100644 --- a/src/primaite/simulator/system/applications/database_client.py +++ b/src/primaite/simulator/system/applications/database_client.py @@ -3,6 +3,8 @@ from typing import Any, Dict, Optional from uuid import uuid4 from primaite import getLogger +from primaite.interface.request import RequestResponse +from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.applications.application import Application @@ -25,26 +27,34 @@ class DatabaseClient(Application): server_password: Optional[str] = None connected: bool = False _query_success_tracker: Dict[str, bool] = {} + _last_connection_successful: Optional[bool] = None + """Keep track of connections that were established or verified during this step. Used for rewards.""" def __init__(self, **kwargs): kwargs["name"] = "DatabaseClient" kwargs["port"] = Port.POSTGRES_SERVER kwargs["protocol"] = IPProtocol.TCP super().__init__(**kwargs) - self.set_original_state() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseClient WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_ip_address", "server_password", "connected", "_query_success_tracker"} - self._original_state.update(self.model_dump(include=vals_to_include)) + def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataBaseClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - self._query_success_tracker.clear() + More information in user guide and docstring for SimComponent._init_request_manager. + """ + rm = super()._init_request_manager() + rm.add_request("execute", RequestType(func=lambda request, context: RequestResponse.from_bool(self.execute()))) + return rm + + def execute(self) -> bool: + """Execution definition for db client: perform a select query.""" + self.num_executions += 1 # trying to connect counts as an execution + if self.connections: + can_connect = self.check_connection(connection_id=list(self.connections.keys())[-1]) + else: + can_connect = self.check_connection(connection_id=str(uuid4())) + self._last_connection_successful = can_connect + return can_connect def describe_state(self) -> Dict: """ @@ -52,8 +62,10 @@ class DatabaseClient(Application): :return: A dictionary representing the current state. """ - pass - return super().describe_state() + state = super().describe_state() + # list of connections that were established or verified during this step. + state["last_connection_successful"] = self._last_connection_successful + return state def configure(self, server_ip_address: IPv4Address, server_password: Optional[str] = None): """ @@ -79,6 +91,18 @@ class DatabaseClient(Application): ) return self.connected + def check_connection(self, connection_id: str) -> bool: + """Check whether the connection can be successfully re-established. + + :param connection_id: connection ID to check + :type connection_id: str + :return: Whether the connection was successfully re-established. + :rtype: bool + """ + if not self._can_perform_action(): + return False + return self.query("SELECT * FROM pg_stat_activity", connection_id=connection_id) + def _connect( self, server_ip_address: IPv4Address, @@ -196,7 +220,11 @@ class DatabaseClient(Application): return False if connection_id is None: - connection_id = str(uuid4()) + if self.connections: + connection_id = list(self.connections.keys())[-1] + # TODO: if the most recent connection dies, it should be automatically cleared. + else: + connection_id = str(uuid4()) if not self.connections.get(connection_id): if not self.connect(connection_id=connection_id): diff --git a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py index a844f059..ee276971 100644 --- a/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/data_manipulation_bot.py @@ -1,10 +1,14 @@ from enum import IntEnum from ipaddress import IPv4Address -from typing import Optional +from typing import Dict, Optional from primaite import getLogger from primaite.game.science import simulate_trial +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -32,12 +36,10 @@ class DataManipulationAttackStage(IntEnum): "Signifies that the attack has failed." -class DataManipulationBot(DatabaseClient): +class DataManipulationBot(Application): """A bot that simulates a script which performs a SQL injection attack.""" - server_ip_address: Optional[IPv4Address] = None payload: Optional[str] = None - server_password: Optional[str] = None port_scan_p_of_success: float = 0.1 data_manipulation_p_of_success: float = 0.1 @@ -46,33 +48,44 @@ class DataManipulationBot(DatabaseClient): "Whether to repeat attacking once finished." def __init__(self, **kwargs): + kwargs["name"] = "DataManipulationBot" + kwargs["port"] = Port.NONE + kwargs["protocol"] = IPProtocol.NONE + super().__init__(**kwargs) - self.name = "DataManipulationBot" - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DataManipulationBot original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "server_ip_address", - "payload", - "server_password", - "port_scan_p_of_success", - "data_manipulation_p_of_success", - "attack_stage", - "repeat", - } - self._original_state.update(self.model_dump(include=vals_to_include)) + def describe_state(self) -> Dict: + """ + Produce a dictionary describing the current state of this object. - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting DataManipulationBot state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) + Please see :py:meth:`primaite.simulator.core.SimComponent.describe_state` for a more detailed explanation. + + :return: Current state of this object and child objects. + :rtype: Dict + """ + state = super().describe_state() + return state + + @property + def _host_db_client(self) -> DatabaseClient: + """Return the database client that is installed on the same machine as the DataManipulationBot.""" + db_client = self.software_manager.software.get("DatabaseClient") + if db_client is None: + _LOGGER.info(f"{self.__class__.__name__} cannot find a database client on its host.") + return db_client def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() - rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.attack())) + rm.add_request( + name="execute", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.attack())), + ) return rm @@ -96,8 +109,8 @@ class DataManipulationBot(DatabaseClient): :param repeat: Whether to repeat attacking once finished. """ self.server_ip_address = server_ip_address - self.payload = payload self.server_password = server_password + self.payload = payload self.port_scan_p_of_success = port_scan_p_of_success self.data_manipulation_p_of_success = data_manipulation_p_of_success self.repeat = repeat @@ -143,15 +156,21 @@ class DataManipulationBot(DatabaseClient): :param p_of_success: Probability of successfully performing data manipulation, by default 0.1. """ + if self._host_db_client is None: + self.attack_stage = DataManipulationAttackStage.FAILED + return + + self._host_db_client.server_ip_address = self.server_ip_address + self._host_db_client.server_password = self.server_password if self.attack_stage == DataManipulationAttackStage.PORT_SCAN: # perform the actual data manipulation attack if simulate_trial(p_of_success): self.sys_log.info(f"{self.name}: Performing data manipulation") # perform the attack - if not len(self.connections): - self.connect() - if len(self.connections): - self.query(self.payload) + if not len(self._host_db_client.connections): + self._host_db_client.connect() + if len(self._host_db_client.connections): + self._host_db_client.query(self.payload) self.sys_log.info(f"{self.name} payload delivered: {self.payload}") attack_successful = True if attack_successful: @@ -169,21 +188,23 @@ class DataManipulationBot(DatabaseClient): """ super().run() - def attack(self): + def attack(self) -> bool: """Perform the attack steps after opening the application.""" if not self._can_perform_action(): _LOGGER.debug("Data manipulation application attempted to execute but it cannot perform actions right now.") self.run() - self._application_loop() - def _application_loop(self): + self.num_executions += 1 + return self._application_loop() + + def _application_loop(self) -> bool: """ The main application loop of the bot, handling the attack process. This is the core loop where the bot sequentially goes through the stages of the attack. """ if not self._can_perform_action(): - return + return False if self.server_ip_address and self.payload: self.sys_log.info(f"{self.name}: Running") self._logon() @@ -195,8 +216,12 @@ class DataManipulationBot(DatabaseClient): DataManipulationAttackStage.FAILED, ): self.attack_stage = DataManipulationAttackStage.NOT_STARTED + + return True + else: self.sys_log.error(f"{self.name}: Failed to start as it requires both a target_ip_address and payload.") + return False def apply_timestep(self, timestep: int) -> None: """ diff --git a/src/primaite/simulator/system/applications/red_applications/dos_bot.py b/src/primaite/simulator/system/applications/red_applications/dos_bot.py index dfc48dd3..27a4da05 100644 --- a/src/primaite/simulator/system/applications/red_applications/dos_bot.py +++ b/src/primaite/simulator/system/applications/red_applications/dos_bot.py @@ -4,9 +4,9 @@ from typing import Optional from primaite import getLogger from primaite.game.science import simulate_trial +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.system.applications.application import Application from primaite.simulator.system.applications.database_client import DatabaseClient _LOGGER = getLogger(__name__) @@ -28,7 +28,7 @@ class DoSAttackStage(IntEnum): "Attack is completed." -class DoSBot(DatabaseClient, Application): +class DoSBot(DatabaseClient): """A bot that simulates a Denial of Service attack.""" target_ip_address: Optional[IPv4Address] = None @@ -57,31 +57,18 @@ class DoSBot(DatabaseClient, Application): self.name = "DoSBot" self.max_sessions = 1000 # override normal max sessions - def set_original_state(self): - """Set the original state of the Denial of Service Bot.""" - _LOGGER.debug(f"Setting {self.name} original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "target_ip_address", - "target_port", - "payload", - "repeat", - "attack_stage", - "max_sessions", - "port_scan_p_of_success", - "dos_intensity", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting {self.name} state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() - rm.add_request(name="execute", request_type=RequestType(func=lambda request, context: self.run())) + rm.add_request( + name="execute", + request_type=RequestType(func=lambda request, context: RequestResponse.from_bool(self.run())), + ) return rm @@ -119,26 +106,26 @@ class DoSBot(DatabaseClient, Application): f"{repeat=}, {port_scan_p_of_success=}, {dos_intensity=}, {max_sessions=}." ) - def run(self): + def run(self) -> bool: """Run the Denial of Service Bot.""" super().run() - self._application_loop() + return self._application_loop() - def _application_loop(self): + def _application_loop(self) -> bool: """ The main application loop for the Denial of Service bot. The loop goes through the stages of a DoS attack. """ if not self._can_perform_action(): - return + return False # DoS bot cannot do anything without a target if not self.target_ip_address or not self.target_port: self.sys_log.error( f"{self.name} is not properly configured. {self.target_ip_address=}, {self.target_port=}" ) - return + return True self.clear_connections() self._perform_port_scan(p_of_success=self.port_scan_p_of_success) @@ -148,6 +135,7 @@ class DoSBot(DatabaseClient, Application): self.attack_stage = DoSAttackStage.NOT_STARTED else: self.attack_stage = DoSAttackStage.COMPLETED + return True def _perform_port_scan(self, p_of_success: Optional[float] = 0.1): """ diff --git a/src/primaite/simulator/system/applications/web_browser.py b/src/primaite/simulator/system/applications/web_browser.py index eef0ed5d..e669ca32 100644 --- a/src/primaite/simulator/system/applications/web_browser.py +++ b/src/primaite/simulator/system/applications/web_browser.py @@ -6,6 +6,7 @@ from urllib.parse import urlparse from pydantic import BaseModel, ConfigDict from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.network.protocols.http import ( HttpRequestMethod, @@ -47,25 +48,20 @@ class WebBrowser(Application): kwargs["port"] = Port.HTTP super().__init__(**kwargs) - self.set_original_state() self.run() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebBrowser original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"target_url", "domain_name_ip_address", "latest_response"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebBrowser state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() rm.add_request( - name="execute", request_type=RequestType(func=lambda request, context: self.get_webpage()) # noqa + name="execute", + request_type=RequestType( + func=lambda request, context: RequestResponse.from_bool(self.get_webpage()) + ), # noqa ) return rm @@ -80,9 +76,6 @@ class WebBrowser(Application): state["history"] = [hist_item.state() for hist_item in self.history] return state - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - def get_webpage(self, url: Optional[str] = None) -> bool: """ Retrieve the webpage. @@ -96,6 +89,8 @@ class WebBrowser(Application): if not self._can_perform_action(): return False + self.num_executions += 1 # trying to connect counts as an execution + # reset latest response self.latest_response = HttpResponsePacket(status_code=HttpStatusCode.NOT_FOUND) @@ -215,7 +210,7 @@ class WebBrowser(Application): def state(self) -> Dict: """Return the contents of this dataclass as a dict for use with describe_state method.""" if self.status == self._HistoryItemStatus.LOADED: - outcome = self.response_code + outcome = self.response_code.value else: outcome = self.status.value return {"url": self.url, "outcome": outcome} diff --git a/src/primaite/simulator/system/core/packet_capture.py b/src/primaite/simulator/system/core/packet_capture.py index fb8a1624..cf38e94b 100644 --- a/src/primaite/simulator/system/core/packet_capture.py +++ b/src/primaite/simulator/system/core/packet_capture.py @@ -21,7 +21,13 @@ class PacketCapture: The PCAPs are logged to: //__pcap.log """ - def __init__(self, hostname: str, ip_address: Optional[str] = None, interface_num: Optional[int] = None): + def __init__( + self, + hostname: str, + ip_address: Optional[str] = None, + port_num: Optional[int] = None, + port_name: Optional[str] = None, + ): """ Initialize the PacketCapture process. @@ -32,16 +38,20 @@ class PacketCapture: "The hostname for which PCAP logs are being recorded." self.ip_address: str = ip_address "The IP address associated with the PCAP logs." - self.interface_num = interface_num + self.port_num = port_num "The interface num on the Node." + self.port_name = port_name + "The interface name on the Node." + self.inbound_logger = None self.outbound_logger = None self.current_episode: int = 1 - self.setup_logger(outbound=False) - self.setup_logger(outbound=True) + if SIM_OUTPUT.save_pcap_logs: + self.setup_logger(outbound=False) + self.setup_logger(outbound=True) def setup_logger(self, outbound: bool = False): """Set up the logger configuration.""" @@ -79,10 +89,12 @@ class PacketCapture: def _get_logger_name(self, outbound: bool = False) -> str: """Get PCAP the logger name.""" + if self.port_name: + return f"{self.hostname}_{self.port_name}_{'outbound' if outbound else 'inbound'}_pcap" if self.ip_address: return f"{self.hostname}_{self.ip_address}_{'outbound' if outbound else 'inbound'}_pcap" - if self.interface_num: - return f"{self.hostname}_port-{self.interface_num}_{'outbound' if outbound else 'inbound'}_pcap" + if self.port_num: + return f"{self.hostname}_port-{self.port_num}_{'outbound' if outbound else 'inbound'}_pcap" return f"{self.hostname}_{'outbound' if outbound else 'inbound'}_pcap" def _get_log_path(self, outbound: bool = False) -> Path: @@ -97,8 +109,9 @@ class PacketCapture: :param frame: The PCAP frame to capture. """ - msg = frame.model_dump_json() - self.inbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL + if SIM_OUTPUT.save_pcap_logs: + msg = frame.model_dump_json() + self.inbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL def capture_outbound(self, frame): # noqa - I'll have a circular import and cant use if TYPE_CHECKING ;( """ @@ -106,5 +119,6 @@ class PacketCapture: :param frame: The PCAP frame to capture. """ - msg = frame.model_dump_json() - self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL + if SIM_OUTPUT.save_pcap_logs: + msg = frame.model_dump_json() + self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL diff --git a/src/primaite/simulator/system/processes/process.py b/src/primaite/simulator/system/processes/process.py index b753e3ad..458a6b5c 100644 --- a/src/primaite/simulator/system/processes/process.py +++ b/src/primaite/simulator/system/processes/process.py @@ -24,12 +24,6 @@ class Process(Software): operating_state: ProcessOperatingState "The current operating state of the Process." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/src/primaite/simulator/system/services/database/database_service.py b/src/primaite/simulator/system/services/database/database_service.py index 0b9554d5..c73132eb 100644 --- a/src/primaite/simulator/system/services/database/database_service.py +++ b/src/primaite/simulator/system/services/database/database_service.py @@ -23,6 +23,7 @@ class DatabaseService(Service): """ password: Optional[str] = None + """Password that needs to be provided by clients if they want to connect to the DatabaseService.""" backup_server_ip: IPv4Address = None """IP address of the backup server.""" @@ -40,25 +41,6 @@ class DatabaseService(Service): super().__init__(**kwargs) self._create_db_file() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DatabaseService original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = { - "password", - "connections", - "backup_server_ip", - "latest_backup_directory", - "latest_backup_file_name", - } - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug("Resetting DatabaseService original state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def configure_backup(self, backup_server: IPv4Address): """ Set up the database backup. @@ -239,6 +221,18 @@ class DatabaseService(Service): } else: return {"status_code": 404, "data": False} + elif query == "SELECT * FROM pg_stat_activity": + # Check if the connection is active. + if self.health_state_actual == SoftwareHealthState.GOOD: + return { + "status_code": 200, + "type": "sql", + "data": False, + "uuid": query_id, + "connection_id": connection_id, + } + else: + return {"status_code": 401, "data": False} else: # Invalid query self.sys_log.info(f"{self.name}: Invalid {query}") diff --git a/src/primaite/simulator/system/services/dns/dns_client.py b/src/primaite/simulator/system/services/dns/dns_client.py index 2d3879ff..967af6b2 100644 --- a/src/primaite/simulator/system/services/dns/dns_client.py +++ b/src/primaite/simulator/system/services/dns/dns_client.py @@ -29,18 +29,6 @@ class DNSClient(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_server"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_cache.clear() - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/dns/dns_server.py b/src/primaite/simulator/system/services/dns/dns_server.py index 8decf7e9..4d0ebbb8 100644 --- a/src/primaite/simulator/system/services/dns/dns_server.py +++ b/src/primaite/simulator/system/services/dns/dns_server.py @@ -28,20 +28,6 @@ class DNSServer(Service): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting DNSServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"dns_table"} - self._original_state["dns_table_orig"] = self.model_dump(include=vals_to_include)["dns_table"] - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - self.dns_table.clear() - for key, value in self._original_state["dns_table_orig"].items(): - self.dns_table[key] = value - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Describes the current state of the software. diff --git a/src/primaite/simulator/system/services/ftp/ftp_client.py b/src/primaite/simulator/system/services/ftp/ftp_client.py index 39bc57f0..7c334ced 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_client.py +++ b/src/primaite/simulator/system/services/ftp/ftp_client.py @@ -27,18 +27,6 @@ class FTPClient(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPClient original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"connected"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPClient state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ftp/ftp_server.py b/src/primaite/simulator/system/services/ftp/ftp_server.py index a82b0919..c5330de2 100644 --- a/src/primaite/simulator/system/services/ftp/ftp_server.py +++ b/src/primaite/simulator/system/services/ftp/ftp_server.py @@ -27,19 +27,6 @@ class FTPServer(FTPServiceABC): super().__init__(**kwargs) self.start() - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting FTPServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"server_password"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting FTPServer state on node {self.software_manager.node.hostname}") - self.clear_connections() - super().reset_component_for_episode(episode) - def _process_ftp_command(self, payload: FTPPacket, session_id: Optional[str] = None, **kwargs) -> FTPPacket: """ Process the command in the FTP Packet. diff --git a/src/primaite/simulator/system/services/ntp/ntp_client.py b/src/primaite/simulator/system/services/ntp/ntp_client.py index 43d1d783..ad00065c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_client.py +++ b/src/primaite/simulator/system/services/ntp/ntp_client.py @@ -49,21 +49,12 @@ class NTPClient(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including resetting any - stateful properties or statistics, and clearing any message queues. - """ - pass - def send( self, payload: NTPPacket, session_id: Optional[str] = None, dest_ip_address: IPv4Address = None, - dest_port: [Port] = Port.NTP, + dest_port: Port = Port.NTP, **kwargs, ) -> bool: """Requests NTP data from NTP server. diff --git a/src/primaite/simulator/system/services/ntp/ntp_server.py b/src/primaite/simulator/system/services/ntp/ntp_server.py index 3ae80936..f9d9ee7c 100644 --- a/src/primaite/simulator/system/services/ntp/ntp_server.py +++ b/src/primaite/simulator/system/services/ntp/ntp_server.py @@ -34,16 +34,6 @@ class NTPServer(Service): state = super().describe_state() return state - def reset_component_for_episode(self, episode: int): - """ - Resets the Service component for a new episode. - - This method ensures the Service is ready for a new episode, including - resetting any stateful properties or statistics, and clearing any message - queues. - """ - pass - def receive( self, payload: NTPPacket, diff --git a/src/primaite/simulator/system/services/service.py b/src/primaite/simulator/system/services/service.py index 162678a0..b2a6f685 100644 --- a/src/primaite/simulator/system/services/service.py +++ b/src/primaite/simulator/system/services/service.py @@ -3,6 +3,7 @@ from enum import Enum from typing import Any, Dict, Optional from primaite import getLogger +from primaite.interface.request import RequestResponse from primaite.simulator.core import RequestManager, RequestType from primaite.simulator.system.software import IOSoftware, SoftwareHealthState @@ -58,7 +59,7 @@ class Service(IOSoftware): if self.operating_state is not ServiceOperatingState.RUNNING: # service is not running - _LOGGER.error(f"Cannot perform action: {self.name} is {self.operating_state.name}") + _LOGGER.debug(f"Cannot perform action: {self.name} is {self.operating_state.name}") return False return True @@ -78,22 +79,21 @@ class Service(IOSoftware): """ return super().receive(payload=payload, session_id=session_id, **kwargs) - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"operating_state", "restart_duration", "restart_countdown"} - self._original_state.update(self.model_dump(include=vals_to_include)) - def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() - rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) - rm.add_request("stop", RequestType(func=lambda request, context: self.stop())) - rm.add_request("start", RequestType(func=lambda request, context: self.start())) - rm.add_request("pause", RequestType(func=lambda request, context: self.pause())) - rm.add_request("resume", RequestType(func=lambda request, context: self.resume())) - rm.add_request("restart", RequestType(func=lambda request, context: self.restart())) - rm.add_request("disable", RequestType(func=lambda request, context: self.disable())) - rm.add_request("enable", RequestType(func=lambda request, context: self.enable())) + rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))) + rm.add_request("stop", RequestType(func=lambda request, context: RequestResponse.from_bool(self.stop()))) + rm.add_request("start", RequestType(func=lambda request, context: RequestResponse.from_bool(self.start()))) + rm.add_request("pause", RequestType(func=lambda request, context: RequestResponse.from_bool(self.pause()))) + rm.add_request("resume", RequestType(func=lambda request, context: RequestResponse.from_bool(self.resume()))) + rm.add_request("restart", RequestType(func=lambda request, context: RequestResponse.from_bool(self.restart()))) + rm.add_request("disable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.disable()))) + rm.add_request("enable", RequestType(func=lambda request, context: RequestResponse.from_bool(self.enable()))) return rm @abstractmethod @@ -112,17 +112,19 @@ class Service(IOSoftware): state["health_state_visible"] = self.health_state_visible.value return state - def stop(self) -> None: + def stop(self) -> bool: """Stop the service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Stopping service {self.name}") self.operating_state = ServiceOperatingState.STOPPED + return True + return False - def start(self, **kwargs) -> None: + def start(self, **kwargs) -> bool: """Start the service.""" # cant start the service if the node it is on is off if not super()._can_perform_action(): - return + return False if self.operating_state == ServiceOperatingState.STOPPED: self.sys_log.info(f"Starting service {self.name}") @@ -130,36 +132,47 @@ class Service(IOSoftware): # set software health state to GOOD if initially set to UNUSED if self.health_state_actual == SoftwareHealthState.UNUSED: self.set_health_state(SoftwareHealthState.GOOD) + return True + return False - def pause(self) -> None: + def pause(self) -> bool: """Pause the service.""" if self.operating_state == ServiceOperatingState.RUNNING: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.PAUSED + return True + return False - def resume(self) -> None: + def resume(self) -> bool: """Resume paused service.""" if self.operating_state == ServiceOperatingState.PAUSED: self.sys_log.info(f"Resuming service {self.name}") self.operating_state = ServiceOperatingState.RUNNING + return True + return False - def restart(self) -> None: + def restart(self) -> bool: """Restart running service.""" if self.operating_state in [ServiceOperatingState.RUNNING, ServiceOperatingState.PAUSED]: self.sys_log.info(f"Pausing service {self.name}") self.operating_state = ServiceOperatingState.RESTARTING self.restart_countdown = self.restart_duration + return True + return False - def disable(self) -> None: + def disable(self) -> bool: """Disable the service.""" self.sys_log.info(f"Disabling Application {self.name}") self.operating_state = ServiceOperatingState.DISABLED + return True - def enable(self) -> None: + def enable(self) -> bool: """Enable the disabled service.""" if self.operating_state == ServiceOperatingState.DISABLED: self.sys_log.info(f"Enabling Application {self.name}") self.operating_state = ServiceOperatingState.STOPPED + return True + return False def apply_timestep(self, timestep: int) -> None: """ diff --git a/src/primaite/simulator/system/services/web_server/web_server.py b/src/primaite/simulator/system/services/web_server/web_server.py index eaea6bb1..5e7591e9 100644 --- a/src/primaite/simulator/system/services/web_server/web_server.py +++ b/src/primaite/simulator/system/services/web_server/web_server.py @@ -23,18 +23,6 @@ class WebServer(Service): last_response_status_code: Optional[HttpStatusCode] = None - def set_original_state(self): - """Sets the original state.""" - _LOGGER.debug(f"Setting WebServer original state on node {self.software_manager.node.hostname}") - super().set_original_state() - vals_to_include = {"last_response_status_code"} - self._original_state.update(self.model_dump(include=vals_to_include)) - - def reset_component_for_episode(self, episode: int): - """Reset the original state of the SimComponent.""" - _LOGGER.debug(f"Resetting WebServer state on node {self.software_manager.node.hostname}") - super().reset_component_for_episode(episode) - def describe_state(self) -> Dict: """ Produce a dictionary describing the current state of this object. @@ -130,7 +118,7 @@ class WebServer(Service): self.set_health_state(SoftwareHealthState.COMPROMISED) return response - except Exception: + except Exception: # TODO: refactor this. Likely to cause silent bugs. (ADO ticket #2345 ) # something went wrong on the server response.status_code = HttpStatusCode.INTERNAL_SERVER_ERROR return response diff --git a/src/primaite/simulator/system/software.py b/src/primaite/simulator/system/software.py index ce39930b..d55f141f 100644 --- a/src/primaite/simulator/system/software.py +++ b/src/primaite/simulator/system/software.py @@ -3,8 +3,9 @@ from abc import abstractmethod from datetime import datetime from enum import Enum from ipaddress import IPv4Address, IPv4Network -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, TYPE_CHECKING, Union +from primaite.interface.request import RequestResponse from primaite.simulator.core import _LOGGER, RequestManager, RequestType, SimComponent from primaite.simulator.file_system.file_system import FileSystem, Folder from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState @@ -13,6 +14,9 @@ from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.session_manager import Session from primaite.simulator.system.core.sys_log import SysLog +if TYPE_CHECKING: + from primaite.simulator.system.core.software_manager import SoftwareManager + class SoftwareType(Enum): """ @@ -84,7 +88,7 @@ class Software(SimComponent): "The count of times the software has been scanned, defaults to 0." revealed_to_red: bool = False "Indicates if the software has been revealed to red agent, defaults is False." - software_manager: Any = None + software_manager: "SoftwareManager" = None "An instance of Software Manager that is used by the parent node." sys_log: SysLog = None "An instance of SysLog that is used by the parent node." @@ -97,34 +101,28 @@ class Software(SimComponent): _patching_countdown: Optional[int] = None "Current number of ticks left to patch the software." - def set_original_state(self): - """Sets the original state.""" - vals_to_include = { - "name", - "health_state_actual", - "health_state_visible", - "criticality", - "patching_count", - "scanning_count", - "revealed_to_red", - } - self._original_state = self.model_dump(include=vals_to_include) - def _init_request_manager(self) -> RequestManager: + """ + Initialise the request manager. + + More information in user guide and docstring for SimComponent._init_request_manager. + """ rm = super()._init_request_manager() rm.add_request( "compromise", RequestType( - func=lambda request, context: self.set_health_state(SoftwareHealthState.COMPROMISED), + func=lambda request, context: RequestResponse.from_bool( + self.set_health_state(SoftwareHealthState.COMPROMISED) + ), ), ) rm.add_request( "patch", RequestType( - func=lambda request, context: self.patch(), + func=lambda request, context: RequestResponse.from_bool(self.patch()), ), ) - rm.add_request("scan", RequestType(func=lambda request, context: self.scan())) + rm.add_request("scan", RequestType(func=lambda request, context: RequestResponse.from_bool(self.scan()))) return rm def _get_session_details(self, session_id: str) -> Session: @@ -158,7 +156,7 @@ class Software(SimComponent): ) return state - def set_health_state(self, health_state: SoftwareHealthState) -> None: + def set_health_state(self, health_state: SoftwareHealthState) -> bool: """ Assign a new health state to this software. @@ -170,6 +168,7 @@ class Software(SimComponent): :type health_state: SoftwareHealthState """ self.health_state_actual = health_state + return True def install(self) -> None: """ @@ -190,15 +189,18 @@ class Software(SimComponent): """ pass - def scan(self) -> None: + def scan(self) -> bool: """Update the observed health status to match the actual health status.""" self.health_state_visible = self.health_state_actual + return True - def patch(self) -> None: + def patch(self) -> bool: """Perform a patch on the software.""" if self.health_state_actual in (SoftwareHealthState.COMPROMISED, SoftwareHealthState.GOOD): self._patching_countdown = self.patching_duration self.set_health_state(SoftwareHealthState.PATCHING) + return True + return False def _update_patch_status(self) -> None: """Update the patch status of the software.""" @@ -248,12 +250,6 @@ class IOSoftware(Software): _connections: Dict[str, Dict] = {} "Active connections." - def set_original_state(self): - """Sets the original state.""" - super().set_original_state() - vals_to_include = {"installing_count", "max_sessions", "tcp", "udp", "port"} - self._original_state.update(self.model_dump(include=vals_to_include)) - @abstractmethod def describe_state(self) -> Dict: """ diff --git a/tests/assets/configs/bad_primaite_session.yaml b/tests/assets/configs/bad_primaite_session.yaml index 73801503..38d54ce3 100644 --- a/tests/assets/configs/bad_primaite_session.yaml +++ b/tests/assets/configs/bad_primaite_session.yaml @@ -21,7 +21,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -46,7 +46,7 @@ agents: frequency: 20 variance: 5 - - ref: client_1_data_manipulation_red_bot + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -589,15 +589,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server diff --git a/tests/assets/configs/basic_firewall.yaml b/tests/assets/configs/basic_firewall.yaml new file mode 100644 index 00000000..9d7b34cb --- /dev/null +++ b/tests/assets/configs/basic_firewall.yaml @@ -0,0 +1,174 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# + +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_2_green_user + team: GREEN + type: ProbabilisticAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_2 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + +simulation: + network: + nodes: + + - ref: firewall + type: firewall + hostname: firewall + start_up_duration: 0 + shut_down_duration: 0 + ports: + external_port: # port 1 + ip_address: 192.168.20.1 + subnet_mask: 255.255.255.0 + internal_port: # port 2 + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + acl: + internal_inbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + internal_outbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + dmz_inbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + dmz_outbound_acl: + 21: + action: PERMIT + protocol: TCP + 22: + action: PERMIT + protocol: UDP + 23: + action: PERMIT + protocol: ICMP + + - ref: switch_1 + type: switch + hostname: switch_1 + num_ports: 8 + - ref: switch_2 + type: switch + hostname: switch_2 + num_ports: 8 + + - ref: client_1 + type: computer + hostname: client_1 + ip_address: 192.168.10.21 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + # pre installed services and applications + - ref: client_2 + type: computer + hostname: client_2 + ip_address: 192.168.10.22 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + # pre installed services and applications + + links: + - ref: switch_1___client_1 + endpoint_a_ref: switch_1 + endpoint_a_port: 1 + endpoint_b_ref: client_1 + endpoint_b_port: 1 + - ref: switch_2___client_2 + endpoint_a_ref: switch_2 + endpoint_a_port: 1 + endpoint_b_ref: client_2 + endpoint_b_port: 1 + - ref: switch_1___firewall + endpoint_a_ref: switch_1 + endpoint_a_port: 2 + endpoint_b_ref: firewall + endpoint_b_port: 1 + - ref: switch_2___firewall + endpoint_a_ref: switch_2 + endpoint_a_port: 2 + endpoint_b_ref: firewall + endpoint_b_port: 2 diff --git a/tests/assets/configs/basic_switched_network.yaml b/tests/assets/configs/basic_switched_network.yaml index d1cec079..9a0d5313 100644 --- a/tests/assets/configs/basic_switched_network.yaml +++ b/tests/assets/configs/basic_switched_network.yaml @@ -1,3 +1,10 @@ +# Basic Switched network +# +# -------------- -------------- -------------- +# | client_1 |------| switch_1 |------| client_2 | +# -------------- -------------- -------------- +# + training_config: rl_framework: SB3 rl_algorithm: PPO @@ -33,7 +40,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -134,6 +141,17 @@ simulation: default_gateway: 192.168.10.1 dns_server: 192.168.1.10 # pre installed services and applications + - ref: client_3 + type: computer + hostname: client_3 + ip_address: 192.168.10.23 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.1.10 + start_up_duration: 0 + shut_down_duration: 0 + operating_state: "OFF" + # pre installed services and applications links: - ref: switch_1___client_1 diff --git a/tests/assets/configs/dmz_network.yaml b/tests/assets/configs/dmz_network.yaml new file mode 100644 index 00000000..95e09e16 --- /dev/null +++ b/tests/assets/configs/dmz_network.yaml @@ -0,0 +1,300 @@ +# Network with DMZ +# +# An example network configuration with an internal network, a DMZ network and a couple of external networks. +# +# ............................................................................ +# . . +# . Internal Network . +# . . +# . -------------- -------------- -------------- . +# . | client_1 |------| switch_1 |--------| router_1 | . +# . -------------- -------------- -------------- . +# . (Computer) | . +# ........................................................|................... +# | +# | +# ........................................................|................... +# . | . +# . DMZ Network | . +# . | . +# . ---------------- -------------- -------------- . +# . | dmz_server |------| switch_2 |------| firewall | . +# . ---------------- -------------- -------------- . +# . (Server) | . +# ........................................................|................... +# | +# External Network | +# | +# | +# ----------------------- -------------- --------------------- +# | external_computer |------| switch_3 |------| external_server | +# ----------------------- -------------- --------------------- +# +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + +agents: + - ref: client_1_green_user + team: GREEN + type: ProbabilisticAgent + observation_space: + type: UC2GreenObservation + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client_1 + applications: + - application_name: WebBrowser + max_folders_per_node: 1 + max_files_per_folder: 1 + max_services_per_node: 1 + max_applications_per_node: 1 + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 5 + frequency: 4 + variance: 3 + + +simulation: + network: + nodes: + - ref: client_1 + type: computer + hostname: client_1 + ip_address: 192.168.0.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.0.1 + dns_server: 192.168.20.11 + start_up_duration: 0 + shut_down_duration: 0 + + - ref: switch_1 + type: switch + hostname: switch_1 + num_ports: 8 + start_up_duration: 0 + shut_down_duration: 0 + + - ref: router_1 + type: router + hostname: router_1 + num_ports: 5 + start_up_duration: 0 + shut_down_duration: 0 + ports: + 1: + ip_address: 192.168.0.1 + subnet_mask: 255.255.255.0 + 2: + ip_address: 192.168.1.1 + subnet_mask: 255.255.255.0 + acl: + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + routes: + - address: 192.168.10.10 # route to dmz_server + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.2 + metric: 0 + - address: 192.168.20.10 # route to external_computer + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.2 + metric: 0 + - address: 192.168.20.11 # route to external_server + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.2 + metric: 0 + + - ref: dmz_server + type: server + hostname: dmz_server + ip_address: 192.168.10.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.10.1 + dns_server: 192.168.20.11 + start_up_duration: 0 + shut_down_duration: 0 + + - ref: switch_2 + type: switch + hostname: switch_2 + num_ports: 8 + start_up_duration: 0 + shut_down_duration: 0 + + - ref: firewall + type: firewall + hostname: firewall + start_up_duration: 0 + shut_down_duration: 0 + ports: + external_port: # port 1 + ip_address: 192.168.20.1 + subnet_mask: 255.255.255.0 + internal_port: # port 2 + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + dmz_port: # port 3 + ip_address: 192.168.10.1 + subnet_mask: 255.255.255.0 + acl: + internal_inbound_acl: + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + internal_outbound_acl: + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + dmz_inbound_acl: + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + dmz_outbound_acl: + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + 23: + action: PERMIT + protocol: ICMP + external_inbound_acl: + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + external_outbound_acl: + 22: + action: PERMIT + src_port: ARP + dst_port: ARP + routes: + - address: 192.168.0.10 # route to client_1 + subnet_mask: 255.255.255.0 + next_hop_ip_address: 192.168.1.1 + metric: 0 + + - ref: switch_3 + type: switch + hostname: switch_3 + num_ports: 8 + start_up_duration: 0 + shut_down_duration: 0 + + - ref: external_computer + type: computer + hostname: external_computer + ip_address: 192.168.20.10 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.20.1 + dns_server: 192.168.20.11 + start_up_duration: 0 + shut_down_duration: 0 + + - ref: external_server + type: server + hostname: external_server + ip_address: 192.168.20.11 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.20.1 + start_up_duration: 0 + shut_down_duration: 0 + services: + - ref: domain_controller_dns_server + type: DNSServer + links: + - ref: client_1___switch_1 + endpoint_a_ref: client_1 + endpoint_a_port: 1 + endpoint_b_ref: switch_1 + endpoint_b_port: 1 + - ref: router_1___switch_1 + endpoint_a_ref: router_1 + endpoint_a_port: 1 + endpoint_b_ref: switch_1 + endpoint_b_port: 8 + - ref: router_1___firewall + endpoint_a_ref: firewall + endpoint_a_port: 2 # internal firewall port + endpoint_b_ref: router_1 + endpoint_b_port: 2 + - ref: firewall___switch_2 + endpoint_a_ref: firewall + endpoint_a_port: 3 # dmz firewall port + endpoint_b_ref: switch_2 + endpoint_b_port: 8 + - ref: dmz_server___switch_2 + endpoint_a_ref: dmz_server + endpoint_a_port: 1 + endpoint_b_ref: switch_2 + endpoint_b_port: 1 + - ref: firewall___switch_3 + endpoint_a_ref: firewall + endpoint_a_port: 1 # external firewall port + endpoint_b_ref: switch_3 + endpoint_b_port: 8 + - ref: external_computer___switch_3 + endpoint_a_ref: external_computer + endpoint_a_port: 1 + endpoint_b_ref: switch_3 + endpoint_b_port: 1 + - ref: external_server___switch_3 + endpoint_a_ref: external_server + endpoint_a_port: 1 + endpoint_b_ref: switch_3 + endpoint_b_port: 2 diff --git a/tests/assets/configs/eval_only_primaite_session.yaml b/tests/assets/configs/eval_only_primaite_session.yaml index 985e764a..f2815578 100644 --- a/tests/assets/configs/eval_only_primaite_session.yaml +++ b/tests/assets/configs/eval_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -51,7 +51,7 @@ agents: frequency: 20 variance: 5 - - ref: client_1_data_manipulation_red_bot + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -593,15 +593,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -624,7 +625,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/multi_agent_session.yaml b/tests/assets/configs/multi_agent_session.yaml index fd5bbbe0..8bbddb76 100644 --- a/tests/assets/configs/multi_agent_session.yaml +++ b/tests/assets/configs/multi_agent_session.yaml @@ -31,7 +31,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -57,7 +57,7 @@ agents: frequency: 20 variance: 5 - - ref: client_1_data_manipulation_red_bot + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -1043,16 +1043,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - - ref: database_server type: server @@ -1074,7 +1074,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/no_nodes_links_agents_network.yaml b/tests/assets/configs/no_nodes_links_agents_network.yaml new file mode 100644 index 00000000..607a899a --- /dev/null +++ b/tests/assets/configs/no_nodes_links_agents_network.yaml @@ -0,0 +1,31 @@ +training_config: + rl_framework: SB3 + rl_algorithm: PPO + seed: 333 + n_learn_episodes: 1 + n_eval_episodes: 5 + max_steps_per_episode: 128 + deterministic_eval: false + n_agents: 1 + agent_references: + - defender + +io_settings: + save_checkpoints: true + checkpoint_interval: 5 + save_step_metadata: false + save_pcap_logs: true + save_sys_logs: true + + +game: + max_episode_length: 256 + ports: + - ARP + - DNS + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP diff --git a/src/primaite/config/_package_data/example_config.yaml b/tests/assets/configs/shared_rewards.yaml similarity index 72% rename from src/primaite/config/_package_data/example_config.yaml rename to tests/assets/configs/shared_rewards.yaml index 6eab6c54..daffa585 100644 --- a/src/primaite/config/_package_data/example_config.yaml +++ b/tests/assets/configs/shared_rewards.yaml @@ -11,29 +11,36 @@ training_config: - defender io_settings: - save_checkpoints: true - checkpoint_interval: 5 + save_agent_actions: false save_step_metadata: false - save_pcap_logs: true - save_sys_logs: true + save_pcap_logs: false + save_sys_logs: false game: max_episode_length: 256 ports: - - ARP - - DNS - HTTP - POSTGRES_SERVER protocols: - ICMP - TCP - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -45,24 +52,45 @@ agents: - node_name: client_2 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 reward_function: reward_components: - - type: DUMMY - - agent_settings: - start_settings: - start_step: 5 - frequency: 4 - variance: 3 + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_2 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_2 - ref: client_1_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.3 + 1: 0.6 + 2: 0.1 observation_space: type: UC2GreenObservation action_space: @@ -74,18 +102,42 @@ agents: - node_name: client_1 applications: - application_name: WebBrowser + - application_name: DatabaseClient max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 - max_applications_per_node: 1 + max_applications_per_node: 2 + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + 2: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 1 + reward_function: reward_components: - - type: DUMMY + - type: WEBPAGE_UNAVAILABLE_PENALTY + weight: 0.25 + options: + node_hostname: client_1 + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 0.05 + options: + node_hostname: client_1 - - ref: client_1_data_manipulation_red_bot + + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -98,14 +150,14 @@ agents: action_list: - type: DONOTHING - type: NODE_APPLICATION_EXECUTE - - type: NODE_FILE_DELETE - - type: NODE_FILE_CORRUPT - - type: NODE_OS_SCAN options: nodes: - node_name: client_1 applications: - application_name: DataManipulationBot + - node_name: client_2 + applications: + - application_name: DataManipulationBot max_folders_per_node: 1 max_files_per_folder: 1 max_services_per_node: 1 @@ -235,99 +287,196 @@ agents: 3: action: "NODE_SERVICE_START" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 4: action: "NODE_SERVICE_PAUSE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 5: action: "NODE_SERVICE_RESUME" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 6: action: "NODE_SERVICE_RESTART" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 7: action: "NODE_SERVICE_DISABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 8: action: "NODE_SERVICE_ENABLE" options: - node_id: 1 - service_id: 0 + node_id: 1 + service_id: 0 9: # check database.db file action: "NODE_FILE_SCAN" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 10: action: "NODE_FILE_CHECKHASH" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 11: action: "NODE_FILE_DELETE" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 12: action: "NODE_FILE_REPAIR" options: - node_id: 2 - folder_id: 0 - file_id: 0 + node_id: 2 + folder_id: 0 + file_id: 0 13: action: "NODE_SERVICE_PATCH" options: - node_id: 2 - service_id: 0 + node_id: 2 + service_id: 0 14: action: "NODE_FOLDER_SCAN" options: - node_id: 2 - folder_id: 0 + node_id: 2 + folder_id: 0 15: action: "NODE_FOLDER_CHECKHASH" options: - node_id: 2 - folder_id: 0 + node_id: 2 + folder_id: 0 16: action: "NODE_FOLDER_REPAIR" options: - node_id: 2 - folder_id: 0 + node_id: 2 + folder_id: 0 17: action: "NODE_FOLDER_RESTORE" options: - node_id: 2 - folder_id: 0 + node_id: 2 + folder_id: 0 18: action: "NODE_OS_SCAN" options: - node_id: 2 - 19: # shutdown client 1 + node_id: 0 + 19: action: "NODE_SHUTDOWN" options: - node_id: 5 + node_id: 0 20: - action: "NODE_STARTUP" + action: NODE_STARTUP options: - node_id: 5 + node_id: 0 21: - action: "NODE_RESET" + action: NODE_RESET options: - node_id: 5 - 22: # "ACL: ADDRULE - Block outgoing traffic from client 1" + node_id: 0 + 22: + action: "NODE_OS_SCAN" + options: + node_id: 1 + 23: + action: "NODE_SHUTDOWN" + options: + node_id: 1 + 24: + action: NODE_STARTUP + options: + node_id: 1 + 25: + action: NODE_RESET + options: + node_id: 1 + 26: # old action num: 18 + action: "NODE_OS_SCAN" + options: + node_id: 2 + 27: + action: "NODE_SHUTDOWN" + options: + node_id: 2 + 28: + action: NODE_STARTUP + options: + node_id: 2 + 29: + action: NODE_RESET + options: + node_id: 2 + 30: + action: "NODE_OS_SCAN" + options: + node_id: 3 + 31: + action: "NODE_SHUTDOWN" + options: + node_id: 3 + 32: + action: NODE_STARTUP + options: + node_id: 3 + 33: + action: NODE_RESET + options: + node_id: 3 + 34: + action: "NODE_OS_SCAN" + options: + node_id: 4 + 35: + action: "NODE_SHUTDOWN" + options: + node_id: 4 + 36: + action: NODE_STARTUP + options: + node_id: 4 + 37: + action: NODE_RESET + options: + node_id: 4 + 38: + action: "NODE_OS_SCAN" + options: + node_id: 5 + 39: # old action num: 19 # shutdown client 1 + action: "NODE_SHUTDOWN" + options: + node_id: 5 + 40: # old action num: 20 + action: NODE_STARTUP + options: + node_id: 5 + 41: # old action num: 21 + action: NODE_RESET + options: + node_id: 5 + 42: + action: "NODE_OS_SCAN" + options: + node_id: 6 + 43: + action: "NODE_SHUTDOWN" + options: + node_id: 6 + 44: + action: NODE_STARTUP + options: + node_id: 6 + 45: + action: NODE_RESET + options: + node_id: 6 + + 46: # old action num: 22 # "ACL: ADDRULE - Block outgoing traffic from client 1" action: "NETWORK_ACL_ADDRULE" options: position: 1 @@ -337,7 +486,7 @@ agents: source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 23: # "ACL: ADDRULE - Block outgoing traffic from client 2" + 47: # old action num: 23 # "ACL: ADDRULE - Block outgoing traffic from client 2" action: "NETWORK_ACL_ADDRULE" options: position: 2 @@ -347,7 +496,7 @@ agents: source_port_id: 1 dest_port_id: 1 protocol_id: 1 - 24: # block tcp traffic from client 1 to web app + 48: # old action num: 24 # block tcp traffic from client 1 to web app action: "NETWORK_ACL_ADDRULE" options: position: 3 @@ -357,7 +506,7 @@ agents: source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 25: # block tcp traffic from client 2 to web app + 49: # old action num: 25 # block tcp traffic from client 2 to web app action: "NETWORK_ACL_ADDRULE" options: position: 4 @@ -367,7 +516,7 @@ agents: source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 26: + 50: # old action num: 26 action: "NETWORK_ACL_ADDRULE" options: position: 5 @@ -377,7 +526,7 @@ agents: source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 27: + 51: # old action num: 27 action: "NETWORK_ACL_ADDRULE" options: position: 6 @@ -387,128 +536,129 @@ agents: source_port_id: 1 dest_port_id: 1 protocol_id: 3 - 28: + 52: # old action num: 28 action: "NETWORK_ACL_REMOVERULE" options: position: 0 - 29: + 53: # old action num: 29 action: "NETWORK_ACL_REMOVERULE" options: position: 1 - 30: + 54: # old action num: 30 action: "NETWORK_ACL_REMOVERULE" options: position: 2 - 31: + 55: # old action num: 31 action: "NETWORK_ACL_REMOVERULE" options: position: 3 - 32: + 56: # old action num: 32 action: "NETWORK_ACL_REMOVERULE" options: position: 4 - 33: + 57: # old action num: 33 action: "NETWORK_ACL_REMOVERULE" options: position: 5 - 34: + 58: # old action num: 34 action: "NETWORK_ACL_REMOVERULE" options: position: 6 - 35: + 59: # old action num: 35 action: "NETWORK_ACL_REMOVERULE" options: position: 7 - 36: + 60: # old action num: 36 action: "NETWORK_ACL_REMOVERULE" options: position: 8 - 37: + 61: # old action num: 37 action: "NETWORK_ACL_REMOVERULE" options: position: 9 - 38: + 62: # old action num: 38 action: "NETWORK_NIC_DISABLE" options: node_id: 0 nic_id: 0 - 39: + 63: # old action num: 39 action: "NETWORK_NIC_ENABLE" options: node_id: 0 nic_id: 0 - 40: + 64: # old action num: 40 action: "NETWORK_NIC_DISABLE" options: node_id: 1 nic_id: 0 - 41: + 65: # old action num: 41 action: "NETWORK_NIC_ENABLE" options: node_id: 1 nic_id: 0 - 42: + 66: # old action num: 42 action: "NETWORK_NIC_DISABLE" options: node_id: 2 nic_id: 0 - 43: + 67: # old action num: 43 action: "NETWORK_NIC_ENABLE" options: node_id: 2 nic_id: 0 - 44: + 68: # old action num: 44 action: "NETWORK_NIC_DISABLE" options: node_id: 3 nic_id: 0 - 45: + 69: # old action num: 45 action: "NETWORK_NIC_ENABLE" options: node_id: 3 nic_id: 0 - 46: + 70: # old action num: 46 action: "NETWORK_NIC_DISABLE" options: node_id: 4 nic_id: 0 - 47: + 71: # old action num: 47 action: "NETWORK_NIC_ENABLE" options: node_id: 4 nic_id: 0 - 48: + 72: # old action num: 48 action: "NETWORK_NIC_DISABLE" options: node_id: 4 nic_id: 1 - 49: + 73: # old action num: 49 action: "NETWORK_NIC_ENABLE" options: node_id: 4 nic_id: 1 - 50: + 74: # old action num: 50 action: "NETWORK_NIC_DISABLE" options: node_id: 5 nic_id: 0 - 51: + 75: # old action num: 51 action: "NETWORK_NIC_ENABLE" options: node_id: 5 nic_id: 0 - 52: + 76: # old action num: 52 action: "NETWORK_NIC_DISABLE" options: node_id: 6 nic_id: 0 - 53: + 77: # old action num: 53 action: "NETWORK_NIC_ENABLE" options: node_id: 6 nic_id: 0 + options: nodes: - node_name: domain_controller @@ -555,20 +705,15 @@ agents: reward_function: reward_components: - - type: DATABASE_FILE_INTEGRITY - weight: 0.34 + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: database_server - folder_name: database - file_name: database.db - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.33 + agent_name: client_1_green_user + - type: SHARED_REWARD + weight: 1.0 options: - node_hostname: client_1 - - type: WEBPAGE_UNAVAILABLE_PENALTY - weight: 0.33 - options: - node_hostname: client_2 + agent_name: client_2_green_user + agent_settings: @@ -580,11 +725,15 @@ agents: simulation: network: + nmne_config: + capture_nmne: true + nmne_capture_keywords: + - DELETE nodes: - ref: router_1 - type: router hostname: router_1 + type: router num_ports: 5 ports: 1: @@ -619,18 +768,18 @@ simulation: protocol: ICMP - ref: switch_1 - type: switch hostname: switch_1 + type: switch num_ports: 8 - ref: switch_2 - type: switch hostname: switch_2 + type: switch num_ports: 8 - ref: domain_controller - type: server hostname: domain_controller + type: server ip_address: 192.168.1.10 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -642,24 +791,25 @@ simulation: arcd.com: 192.168.1.12 # web server - ref: web_server - type: server hostname: web_server + type: server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server - type: server hostname: database_server + type: server ip_address: 192.168.1.14 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -673,8 +823,8 @@ simulation: type: FTPClient - ref: backup_server - type: server hostname: backup_server + type: server ip_address: 192.168.1.16 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -684,8 +834,8 @@ simulation: type: FTPServer - ref: security_suite - type: server hostname: security_suite + type: server ip_address: 192.168.1.110 subnet_mask: 255.255.255.0 default_gateway: 192.168.1.1 @@ -696,8 +846,8 @@ simulation: subnet_mask: 255.255.255.0 - ref: client_1 - type: computer hostname: client_1 + type: computer ip_address: 192.168.10.21 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -714,13 +864,17 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: client_1_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_1_dns_client type: DNSClient - ref: client_2 - type: computer hostname: client_2 + type: computer ip_address: 192.168.10.22 subnet_mask: 255.255.255.0 default_gateway: 192.168.10.1 @@ -730,6 +884,17 @@ simulation: type: WebBrowser options: target_url: http://arcd.com/users/ + - ref: data_manipulation_bot + type: DataManipulationBot + options: + port_scan_p_of_success: 0.8 + data_manipulation_p_of_success: 0.8 + payload: "DELETE" + server_ip: 192.168.1.14 + - ref: client_2_database_client + type: DatabaseClient + options: + db_server_ip: 192.168.1.14 services: - ref: client_2_dns_client type: DNSClient diff --git a/tests/assets/configs/test_primaite_session.yaml b/tests/assets/configs/test_primaite_session.yaml index 398aa915..121cc7f1 100644 --- a/tests/assets/configs/test_primaite_session.yaml +++ b/tests/assets/configs/test_primaite_session.yaml @@ -11,8 +11,11 @@ training_config: - defender io_settings: - save_checkpoints: true - checkpoint_interval: 5 + save_agent_actions: true + save_step_metadata: true + save_pcap_logs: true + save_sys_logs: true + game: @@ -29,7 +32,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -55,7 +58,7 @@ agents: frequency: 20 variance: 5 - - ref: client_1_data_manipulation_red_bot + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -599,15 +602,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -630,7 +634,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/assets/configs/train_only_primaite_session.yaml b/tests/assets/configs/train_only_primaite_session.yaml index ced5ae74..71a23989 100644 --- a/tests/assets/configs/train_only_primaite_session.yaml +++ b/tests/assets/configs/train_only_primaite_session.yaml @@ -25,7 +25,7 @@ game: agents: - ref: client_2_green_user team: GREEN - type: GreenWebBrowsingAgent + type: ProbabilisticAgent observation_space: type: UC2GreenObservation action_space: @@ -58,7 +58,7 @@ agents: frequency: 20 variance: 5 - - ref: client_1_data_manipulation_red_bot + - ref: data_manipulation_attacker team: RED type: RedDatabaseCorruptingAgent @@ -600,15 +600,16 @@ simulation: hostname: web_server ip_address: 192.168.1.12 subnet_mask: 255.255.255.0 - default_gateway: 192.168.1.10 + default_gateway: 192.168.1.1 dns_server: 192.168.1.10 services: + - ref: web_server_web_service + type: WebServer + applications: - ref: web_server_database_client type: DatabaseClient options: db_server_ip: 192.168.1.14 - - ref: web_server_web_service - type: WebServer - ref: database_server @@ -631,7 +632,7 @@ simulation: dns_server: 192.168.1.10 services: - ref: backup_service - type: DatabaseBackup + type: FTPServer - ref: security_suite type: server diff --git a/tests/conftest.py b/tests/conftest.py index 5084c339..3a9e2655 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,21 @@ # © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK +from datetime import datetime from pathlib import Path -from typing import Any, Dict, Tuple, Union +from typing import Any, Dict, Optional, Tuple, Union import pytest import yaml +from _pytest.monkeypatch import MonkeyPatch from primaite import getLogger, PRIMAITE_PATHS from primaite.game.agent.actions import ActionManager from primaite.game.agent.interface import AbstractAgent -from primaite.game.agent.observations import ICSObservation, ObservationManager +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations.observations import ICSObservation from primaite.game.agent.rewards import RewardFunction from primaite.game.game import PrimaiteGame from primaite.session.session import PrimaiteSession +from primaite.simulator import SIM_OUTPUT from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -29,6 +33,7 @@ from primaite.simulator.system.services.dns.dns_client import DNSClient from primaite.simulator.system.services.dns.dns_server import DNSServer from primaite.simulator.system.services.service import Service from primaite.simulator.system.services.web_server.web_server import WebServer +from tests import TEST_ASSETS_ROOT from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path ACTION_SPACE_NODE_VALUES = 1 @@ -37,6 +42,21 @@ ACTION_SPACE_NODE_ACTION_VALUES = 1 _LOGGER = getLogger(__name__) +@pytest.fixture(scope="function", autouse=True) +def set_syslog_output_to_true(): + """Will be run before each test.""" + monkeypatch = MonkeyPatch() + monkeypatch.setattr( + SIM_OUTPUT, + "path", + Path(TEST_ASSETS_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")), + ) + monkeypatch.setattr(SIM_OUTPUT, "save_pcap_logs", True) + monkeypatch.setattr(SIM_OUTPUT, "save_sys_logs", True) + + yield + + class TestService(Service): """Test Service class""" @@ -309,7 +329,7 @@ class ControlledAgent(AbstractAgent): ) self.most_recent_action: Tuple[str, Dict] - def get_action(self, obs: None, reward: float = 0.0) -> Tuple[str, Dict]: + def get_action(self, obs: None, timestep: int = 0) -> Tuple[str, Dict]: """Return the agent's most recent action, formatted in CAOS format.""" return self.most_recent_action @@ -403,7 +423,7 @@ def install_stuff_to_sim(sim: Simulation): assert len(sim.network.nodes) == 6 assert len(sim.network.links) == 5 # 5.1: Assert the router is correctly configured - r = sim.network.routers[0] + r = sim.network.router_nodes[0] for i, acl_rule in enumerate(r.acl.acl): if i == 1: assert acl_rule.src_port == acl_rule.dst_port == Port.DNS @@ -478,7 +498,6 @@ def game_and_agent(): ] action_space = ActionManager( - game=game, actions=actions, # ALL POSSIBLE ACTIONS nodes=[ { @@ -510,6 +529,8 @@ def game_and_agent(): reward_function=reward_function, ) - game.agents.append(test_agent) + game.agents["test_agent"] = test_agent + + game.setup_reward_sharing() return (game, test_agent) diff --git a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py index 3934ce5b..84897f9a 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_multi_agent_environment.py @@ -4,7 +4,7 @@ import yaml from ray import air, tune from ray.rllib.algorithms.ppo import PPOConfig -from primaite.config.load import example_config_path +from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteRayMARLEnv @@ -13,7 +13,7 @@ from primaite.session.environment import PrimaiteRayMARLEnv def test_rllib_multi_agent_compatibility(): """Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system.""" - with open(example_config_path(), "r") as f: + with open(data_manipulation_config_path(), "r") as f: cfg = yaml.safe_load(f) game = PrimaiteGame.from_config(cfg) diff --git a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py index 2b12ad98..4c4b8d8d 100644 --- a/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py +++ b/tests/e2e_integration_tests/environments/test_rllib_single_agent_environment.py @@ -6,7 +6,7 @@ import ray import yaml from ray.rllib.algorithms import ppo -from primaite.config.load import example_config_path +from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteRayEnv @@ -14,7 +14,7 @@ from primaite.session.environment import PrimaiteRayEnv @pytest.mark.skip(reason="Slow, reenable later") def test_rllib_single_agent_compatibility(): """Test that the PrimaiteRayEnv class can be used with a single agent RLLIB system.""" - with open(example_config_path(), "r") as f: + with open(data_manipulation_config_path(), "r") as f: cfg = yaml.safe_load(f) game = PrimaiteGame.from_config(cfg) diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 91cf5c1e..83965191 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -6,19 +6,17 @@ import pytest import yaml from stable_baselines3 import PPO -from primaite.config.load import example_config_path +from primaite.config.load import data_manipulation_config_path from primaite.game.game import PrimaiteGame from primaite.session.environment import PrimaiteGymEnv -# @pytest.mark.skip(reason="no way of currently testing this") def test_sb3_compatibility(): """Test that the Gymnasium environment can be used with an SB3 agent.""" - with open(example_config_path(), "r") as f: + with open(data_manipulation_config_path(), "r") as f: cfg = yaml.safe_load(f) - game = PrimaiteGame.from_config(cfg) - gym = PrimaiteGymEnv(game=game) + gym = PrimaiteGymEnv(game_config=cfg) model = PPO("MlpPolicy", gym) model.learn(total_timesteps=1000) diff --git a/tests/e2e_integration_tests/test_primaite_session.py b/tests/e2e_integration_tests/test_primaite_session.py index 7785e4ae..c45a4690 100644 --- a/tests/e2e_integration_tests/test_primaite_session.py +++ b/tests/e2e_integration_tests/test_primaite_session.py @@ -21,16 +21,17 @@ class TestPrimaiteSession: raise AssertionError assert session is not None - assert session.game.simulation - assert len(session.game.agents) == 3 - assert len(session.game.rl_agents) == 1 + assert session.env.game.simulation + assert len(session.env.game.agents) == 3 + assert len(session.env.game.rl_agents) == 1 assert session.policy assert session.env - assert session.game.simulation.network - assert len(session.game.simulation.network.nodes) == 10 + assert session.env.game.simulation.network + assert len(session.env.game.simulation.network.nodes) == 10 + @pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.") @pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True) def test_start_session(self, temp_primaite_session): """Make sure you can go all the way through the session without errors.""" diff --git a/tests/integration_tests/component_creation/test_action_integration.py b/tests/integration_tests/component_creation/test_action_integration.py index f41a57af..e7c9fcc6 100644 --- a/tests/integration_tests/component_creation/test_action_integration.py +++ b/tests/integration_tests/component_creation/test_action_integration.py @@ -12,6 +12,8 @@ def test_passing_actions_down(monkeypatch) -> None: sim = Simulation() pc1 = Computer(hostname="PC-1", ip_address="10.10.1.1", subnet_mask="255.255.255.0") + pc1.start_up_duration = 0 + pc1.power_on() pc2 = Computer(hostname="PC-2", ip_address="10.10.1.2", subnet_mask="255.255.255.0") srv = Server(hostname="WEBSERVER", ip_address="10.10.1.100", subnet_mask="255.255.255.0") s1 = Switch(hostname="switch1") diff --git a/tests/integration_tests/configuration_file_parsing/__init__.py b/tests/integration_tests/configuration_file_parsing/__init__.py new file mode 100644 index 00000000..be21c036 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/__init__.py @@ -0,0 +1,21 @@ +from pathlib import Path +from typing import Union + +import yaml + +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + +DMZ_NETWORK = TEST_ASSETS_ROOT / "configs/dmz_network.yaml" + +BASIC_FIREWALL = TEST_ASSETS_ROOT / "configs/basic_firewall.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) diff --git a/tests/integration_tests/configuration_file_parsing/nodes/__init__.py b/tests/integration_tests/configuration_file_parsing/nodes/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/__init__.py b/tests/integration_tests/configuration_file_parsing/nodes/network/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py new file mode 100644 index 00000000..fc6e05ec --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_firewall_config.py @@ -0,0 +1,135 @@ +from ipaddress import IPv4Address + +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.firewall import Firewall +from primaite.simulator.network.hardware.nodes.network.router import ACLAction +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from tests.integration_tests.configuration_file_parsing import BASIC_FIREWALL, DMZ_NETWORK, load_config + + +@pytest.fixture(scope="function") +def dmz_config() -> Network: + game = load_config(DMZ_NETWORK) + return game.simulation.network + + +@pytest.fixture(scope="function") +def basic_firewall_config() -> Network: + game = load_config(BASIC_FIREWALL) + return game.simulation.network + + +def test_firewall_is_in_configuration(dmz_config): + """Test that the firewall exists in the configuration file.""" + network: Network = dmz_config + + firewall: Firewall = network.get_node_by_hostname("firewall") + + assert firewall + assert firewall.operating_state == NodeOperatingState.ON + + +def test_firewall_routes_are_correctly_added(dmz_config): + """Test that the firewall routes have been correctly added to and configured in the network.""" + network: Network = dmz_config + + firewall: Firewall = network.get_node_by_hostname("firewall") + client_1: Computer = network.get_node_by_hostname("client_1") + dmz_server: Server = network.get_node_by_hostname("dmz_server") + external_computer: Computer = network.get_node_by_hostname("external_computer") + external_server: Server = network.get_node_by_hostname("external_server") + + # there should be a route to client_1 + assert firewall.route_table.find_best_route(client_1.network_interface[1].ip_address) + assert dmz_server.ping(client_1.network_interface[1].ip_address) + assert external_computer.ping(client_1.network_interface[1].ip_address) + assert external_server.ping(client_1.network_interface[1].ip_address) + + # client_1 should be able to ping other nodes + assert client_1.ping(dmz_server.network_interface[1].ip_address) + assert client_1.ping(external_computer.network_interface[1].ip_address) + assert client_1.ping(external_server.network_interface[1].ip_address) + + +def test_firewall_acl_rules_correctly_added(dmz_config): + """ + Test that makes sure that the firewall ACLs have been configured onto the firewall + node via configuration file. + """ + firewall: Firewall = dmz_config.get_node_by_hostname("firewall") + + # ICMP and ARP should be allowed internal_inbound + assert firewall.internal_inbound_acl.num_rules == 2 + assert firewall.internal_inbound_acl.acl[22].action == ACLAction.PERMIT + assert firewall.internal_inbound_acl.acl[22].src_port == Port.ARP + assert firewall.internal_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_inbound_acl.acl[23].action == ACLAction.PERMIT + assert firewall.internal_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_inbound_acl.implicit_action == ACLAction.DENY + + # ICMP and ARP should be allowed internal_outbound + assert firewall.internal_outbound_acl.num_rules == 2 + assert firewall.internal_outbound_acl.acl[22].action == ACLAction.PERMIT + assert firewall.internal_outbound_acl.acl[22].src_port == Port.ARP + assert firewall.internal_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.internal_outbound_acl.acl[23].action == ACLAction.PERMIT + assert firewall.internal_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.internal_outbound_acl.implicit_action == ACLAction.DENY + + # ICMP and ARP should be allowed dmz_inbound + assert firewall.dmz_inbound_acl.num_rules == 2 + assert firewall.dmz_inbound_acl.acl[22].action == ACLAction.PERMIT + assert firewall.dmz_inbound_acl.acl[22].src_port == Port.ARP + assert firewall.dmz_inbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_inbound_acl.acl[23].action == ACLAction.PERMIT + assert firewall.dmz_inbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_inbound_acl.implicit_action == ACLAction.DENY + + # ICMP and ARP should be allowed dmz_outbound + assert firewall.dmz_outbound_acl.num_rules == 2 + assert firewall.dmz_outbound_acl.acl[22].action == ACLAction.PERMIT + assert firewall.dmz_outbound_acl.acl[22].src_port == Port.ARP + assert firewall.dmz_outbound_acl.acl[22].dst_port == Port.ARP + assert firewall.dmz_outbound_acl.acl[23].action == ACLAction.PERMIT + assert firewall.dmz_outbound_acl.acl[23].protocol == IPProtocol.ICMP + assert firewall.dmz_outbound_acl.implicit_action == ACLAction.DENY + + # ICMP and ARP should be allowed external_inbound + assert firewall.external_inbound_acl.num_rules == 1 + assert firewall.external_inbound_acl.acl[22].action == ACLAction.PERMIT + assert firewall.external_inbound_acl.acl[22].src_port == Port.ARP + assert firewall.external_inbound_acl.acl[22].dst_port == Port.ARP + # external_inbound should have implicit action PERMIT + # ICMP does not have a provided ACL Rule but implicit action should allow anything + assert firewall.external_inbound_acl.implicit_action == ACLAction.PERMIT + + # ICMP and ARP should be allowed external_outbound + assert firewall.external_outbound_acl.num_rules == 1 + assert firewall.external_outbound_acl.acl[22].action == ACLAction.PERMIT + assert firewall.external_outbound_acl.acl[22].src_port == Port.ARP + assert firewall.external_outbound_acl.acl[22].dst_port == Port.ARP + # external_outbound should have implicit action PERMIT + # ICMP does not have a provided ACL Rule but implicit action should allow anything + assert firewall.external_outbound_acl.implicit_action == ACLAction.PERMIT + + +def test_firewall_with_no_dmz_port(basic_firewall_config): + """ + Test to check that: + - the DMZ port can be ignored i.e. is optional. + - the external_outbound_acl and external_inbound_acl are optional + """ + network: Network = basic_firewall_config + + firewall: Firewall = network.get_node_by_hostname("firewall") + + assert firewall.dmz_port.ip_address == IPv4Address("127.0.0.1") + + assert firewall.external_outbound_acl.num_rules == 0 + assert firewall.external_inbound_acl.num_rules == 0 diff --git a/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py new file mode 100644 index 00000000..4382cc30 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/nodes/network/test_router_config.py @@ -0,0 +1,69 @@ +import pytest + +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.network_layer import IPProtocol +from primaite.simulator.network.transmission.transport_layer import Port +from tests.integration_tests.configuration_file_parsing import DMZ_NETWORK, load_config + + +@pytest.fixture(scope="function") +def dmz_config() -> Network: + game = load_config(DMZ_NETWORK) + return game.simulation.network + + +def test_router_is_in_configuration(dmz_config): + """Test that the router exists in the configuration file.""" + network: Network = dmz_config + + router_1: Router = network.get_node_by_hostname("router_1") + + assert router_1 + assert router_1.operating_state == NodeOperatingState.ON + + +def test_router_routes_are_correctly_added(dmz_config): + """Test that makes sure that router routes have been added from the configuration file.""" + network: Network = dmz_config + + router_1: Router = network.get_node_by_hostname("router_1") + client_1: Computer = network.get_node_by_hostname("client_1") + dmz_server: Server = network.get_node_by_hostname("dmz_server") + external_computer: Computer = network.get_node_by_hostname("external_computer") + external_server: Server = network.get_node_by_hostname("external_server") + + # there should be a route to dmz_server + assert router_1.route_table.find_best_route(dmz_server.network_interface[1].ip_address) + assert client_1.ping(dmz_server.network_interface[1].ip_address) + assert external_computer.ping(dmz_server.network_interface[1].ip_address) + assert external_server.ping(dmz_server.network_interface[1].ip_address) + + # there should be a route to external_computer + assert router_1.route_table.find_best_route(external_computer.network_interface[1].ip_address) + assert client_1.ping(external_computer.network_interface[1].ip_address) + assert dmz_server.ping(external_computer.network_interface[1].ip_address) + assert external_server.ping(external_computer.network_interface[1].ip_address) + + # there should be a route to external_server + assert router_1.route_table.find_best_route(external_server.network_interface[1].ip_address) + assert client_1.ping(external_server.network_interface[1].ip_address) + assert dmz_server.ping(external_server.network_interface[1].ip_address) + assert external_computer.ping(external_server.network_interface[1].ip_address) + + +def test_router_acl_rules_correctly_added(dmz_config): + """Test that makes sure that the router ACLs have been configured onto the router node via configuration file.""" + router_1: Router = dmz_config.get_node_by_hostname("router_1") + + # ICMP and ARP should be allowed + assert router_1.acl.num_rules == 2 + assert router_1.acl.acl[22].action == ACLAction.PERMIT + assert router_1.acl.acl[22].src_port == Port.ARP + assert router_1.acl.acl[22].dst_port == Port.ARP + assert router_1.acl.acl[23].action == ACLAction.PERMIT + assert router_1.acl.acl[23].protocol == IPProtocol.ICMP + assert router_1.acl.implicit_action == ACLAction.DENY diff --git a/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py new file mode 100644 index 00000000..174bd0c0 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/nodes/test_node_config.py @@ -0,0 +1,45 @@ +from primaite.config.load import data_manipulation_config_path +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from tests.integration_tests.configuration_file_parsing import BASIC_CONFIG, DMZ_NETWORK, load_config + + +def test_example_config(): + """Test that the example config can be parsed properly.""" + game = load_config(data_manipulation_config_path()) + network: Network = game.simulation.network + + assert len(network.nodes) == 10 # 10 nodes in example network + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.switch_nodes) == 2 # 2 switches in network + assert len(network.server_nodes) == 5 # 5 servers in network + + +def test_dmz_config(): + """Test that the DMZ network config can be parsed properly.""" + game = load_config(DMZ_NETWORK) + + network: Network = game.simulation.network + + assert len(network.nodes) == 9 # 9 nodes in network + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.firewall_nodes) == 1 # 1 firewall in network + assert len(network.switch_nodes) == 3 # 3 switches in network + assert len(network.server_nodes) == 2 # 2 servers in network + + +def test_basic_config(): + """Test that the basic_switched_network config can be parsed properly.""" + game = load_config(BASIC_CONFIG) + network: Network = game.simulation.network + assert len(network.nodes) == 4 # 4 nodes in network + + client_1: Computer = network.get_node_by_hostname("client_1") + assert client_1.operating_state == NodeOperatingState.ON + client_2: Computer = network.get_node_by_hostname("client_2") + assert client_2.operating_state == NodeOperatingState.ON + + # client 3 should not be online + client_3: Computer = network.get_node_by_hostname("client_3") + assert client_3.operating_state == NodeOperatingState.OFF diff --git a/tests/integration_tests/game_configuration.py b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py similarity index 89% rename from tests/integration_tests/game_configuration.py rename to tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py index 3bd870e3..a5fcb372 100644 --- a/tests/integration_tests/game_configuration.py +++ b/tests/integration_tests/configuration_file_parsing/software_installation_and_configuration.py @@ -4,9 +4,10 @@ from typing import Union import yaml -from primaite.config.load import example_config_path -from primaite.game.agent.data_manipulation_bot import DataManipulationAgent -from primaite.game.agent.interface import ProxyAgent, RandomAgent +from primaite.config.load import data_manipulation_config_path +from primaite.game.agent.interface import ProxyAgent +from primaite.game.agent.scripted_agents.data_manipulation_bot import DataManipulationAgent +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent from primaite.game.game import APPLICATION_TYPES_MAPPING, PrimaiteGame, SERVICE_TYPES_MAPPING from primaite.simulator.network.container import Network from primaite.simulator.network.hardware.nodes.host.computer import Computer @@ -37,32 +38,32 @@ def load_config(config_path: Union[str, Path]) -> PrimaiteGame: def test_example_config(): """Test that the example config can be parsed properly.""" - game = load_config(example_config_path()) + game = load_config(data_manipulation_config_path()) assert len(game.agents) == 4 # red, blue and 2 green agents # green agent 1 - assert game.agents[0].agent_name == "client_2_green_user" - assert isinstance(game.agents[0], RandomAgent) + assert "client_2_green_user" in game.agents + assert isinstance(game.agents["client_2_green_user"], ProbabilisticAgent) # green agent 2 - assert game.agents[1].agent_name == "client_1_green_user" - assert isinstance(game.agents[1], RandomAgent) + assert "client_1_green_user" in game.agents + assert isinstance(game.agents["client_1_green_user"], ProbabilisticAgent) # red agent - assert game.agents[2].agent_name == "client_1_data_manipulation_red_bot" - assert isinstance(game.agents[2], DataManipulationAgent) + assert "data_manipulation_attacker" in game.agents + assert isinstance(game.agents["data_manipulation_attacker"], DataManipulationAgent) # blue agent - assert game.agents[3].agent_name == "defender" - assert isinstance(game.agents[3], ProxyAgent) + assert "defender" in game.agents + assert isinstance(game.agents["defender"], ProxyAgent) network: Network = game.simulation.network assert len(network.nodes) == 10 # 10 nodes in example network - assert len(network.routers) == 1 # 1 router in network - assert len(network.switches) == 2 # 2 switches in network - assert len(network.servers) == 5 # 5 servers in network + assert len(network.router_nodes) == 1 # 1 router in network + assert len(network.switch_nodes) == 2 # 2 switches in network + assert len(network.server_nodes) == 5 # 5 servers in network def test_node_software_install(): diff --git a/tests/integration_tests/configuration_file_parsing/test_game_options_config.py b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py new file mode 100644 index 00000000..adbbf2b5 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_game_options_config.py @@ -0,0 +1,25 @@ +from pathlib import Path +from typing import Union + +import yaml + +from primaite.config.load import data_manipulation_config_path +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +def test_thresholds(): + """Test that the game options can be parsed correctly.""" + game = load_config(data_manipulation_config_path()) + + assert game.options.thresholds is not None diff --git a/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py b/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py new file mode 100644 index 00000000..5c9b0cb9 --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_no_nodes_links_agents_config.py @@ -0,0 +1,19 @@ +import yaml + +from primaite.game.game import PrimaiteGame +from tests import TEST_ASSETS_ROOT + +CONFIG_FILE = TEST_ASSETS_ROOT / "configs" / "no_nodes_links_agents_network.yaml" + + +def test_no_nodes_links_agents_config(): + """Tests PrimaiteGame can be created from config file where there are no nodes, links, agents in the config file.""" + with open(CONFIG_FILE, "r") as f: + cfg = yaml.safe_load(f) + + game = PrimaiteGame.from_config(cfg) + + network = game.simulation.network + + assert len(network.nodes) == 0 + assert len(network.links) == 0 diff --git a/tests/integration_tests/game_layer/observations/__init__.py b/tests/integration_tests/game_layer/observations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/game_layer/observations/test_acl_observations.py b/tests/integration_tests/game_layer/observations/test_acl_observations.py new file mode 100644 index 00000000..93867edd --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_acl_observations.py @@ -0,0 +1,66 @@ +import pytest + +from primaite.game.agent.observations.observations import AclObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.services.ntp.ntp_client import NTPClient +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_acl_observations(simulation): + """Test the ACL rule observations.""" + router: Router = simulation.network.get_node_by_hostname("router_1") + client_1: Computer = simulation.network.get_node_by_hostname("client_1") + server: Computer = simulation.network.get_node_by_hostname("server_1") + + # quick set up of ntp + client_1.software_manager.install(NTPClient) + ntp_client: NTPClient = client_1.software_manager.software.get("NTPClient") + ntp_client.configure(server.network_interface.get(1).ip_address) + server.software_manager.install(NTPServer) + + # add router acl rule + router.acl.add_rule(action=ACLAction.PERMIT, dst_port=Port.NTP, src_port=Port.NTP, position=1) + + acl_obs = AclObservation( + where=["network", "nodes", router.hostname, "acl", "acl"], + node_ip_to_id={}, + ports=["NTP", "HTTP", "POSTGRES_SERVER"], + protocols=["TCP", "UDP", "ICMP"], + ) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP + assert rule_obs.get("position") == 0 # rule was put at position 1 (0 because counting from 1 instead of 1) + assert rule_obs.get("permission") == 1 # permit = 1 deny = 2 + assert rule_obs.get("source_node_id") == 1 # applies to all source nodes + assert rule_obs.get("dest_node_id") == 1 # applies to all destination nodes + assert rule_obs.get("source_port") == 2 # NTP port is mapped to value 2 (1 = ALL, so 1+1 = 2 quik mafs) + assert rule_obs.get("dest_port") == 2 # NTP port is mapped to value 2 + assert rule_obs.get("protocol") == 1 # 1 = No Protocol + + router.acl.remove_rule(1) + + observation_space = acl_obs.observe(simulation.describe_state()) + assert observation_space.get(1) is not None + rule_obs = observation_space.get(1) # this is the ACL Rule added to allow NTP + assert rule_obs.get("position") == 0 + assert rule_obs.get("permission") == 0 + assert rule_obs.get("source_node_id") == 0 + assert rule_obs.get("dest_node_id") == 0 + assert rule_obs.get("source_port") == 0 + assert rule_obs.get("dest_port") == 0 + assert rule_obs.get("protocol") == 0 diff --git a/tests/integration_tests/game_layer/observations/test_file_system_observations.py b/tests/integration_tests/game_layer/observations/test_file_system_observations.py new file mode 100644 index 00000000..35bb95fd --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_file_system_observations.py @@ -0,0 +1,70 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.file_system_observations import FileObservation, FolderObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_file_observation(simulation): + """Test the file observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file on the pc + file = pc.file_system.create_file(file_name="dog.png") + + dog_file_obs = FileObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "root", "files", "dog.png"] + ) + + assert dog_file_obs.space["health_status"] == spaces.Discrete(6) + + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # good initial + + file.corrupt() + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan file so this changes + + file.scan() + file.apply_timestep(0) # apply time step + observation_state = dog_file_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # corrupted + + +def test_folder_observation(simulation): + """Test the folder observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # create a file and folder on the pc + folder = pc.file_system.create_folder("test_folder") + file = pc.file_system.create_file(file_name="dog.png", folder_name="test_folder") + + root_folder_obs = FolderObservation( + where=["network", "nodes", pc.hostname, "file_system", "folders", "test_folder"] + ) + + assert root_folder_obs.space["health_status"] == spaces.Discrete(6) + + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("FILES") is not None + assert observation_state.get("health_status") == 1 + + file.corrupt() # corrupt just the file + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 # scan folder to change this + + folder.scan() + for i in range(folder.scan_duration + 1): + folder.apply_timestep(i) # apply as many timesteps as needed for a scan + + observation_state = root_folder_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 3 # file is corrupt therefore folder is corrupted too diff --git a/tests/integration_tests/game_layer/observations/test_link_observations.py b/tests/integration_tests/game_layer/observations/test_link_observations.py new file mode 100644 index 00000000..bfe4d5cc --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_link_observations.py @@ -0,0 +1,73 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.observations import LinkObservation +from primaite.simulator.network.container import Network +from primaite.simulator.network.hardware.base import Link, Node +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation() -> Simulation: + sim = Simulation() + + network = Network() + + # Create Computer + computer = Computer( + hostname="computer", + ip_address="192.168.1.2", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + computer.power_on() + + # Create Server + server = Server( + hostname="server", + ip_address="192.168.1.3", + subnet_mask="255.255.255.0", + default_gateway="192.168.1.1", + start_up_duration=0, + ) + server.power_on() + + # Connect Computer and Server + network.connect(computer.network_interface[1], server.network_interface[1]) + + # Should be linked + assert next(iter(network.links.values())).is_up + + assert computer.ping(server.network_interface.get(1).ip_address) + + # set simulation network as example network + sim.network = network + + return sim + + +def test_link_observation(simulation): + """Test the link observation.""" + # get a link + link: Link = next(iter(simulation.network.links.values())) + + computer: Computer = simulation.network.get_node_by_hostname("computer") + server: Server = simulation.network.get_node_by_hostname("server") + + simulation.apply_timestep(0) # some pings when network was made - reset with apply timestep + + link_obs = LinkObservation(where=["network", "links", link.uuid]) + + assert link_obs.space["PROTOCOLS"]["ALL"] == spaces.Discrete(11) # test that the spaces are 0-10 including 0 and 10 + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state.get("PROTOCOLS") is not None + assert observation_state["PROTOCOLS"]["ALL"] == 0 + + computer.ping(server.network_interface.get(1).ip_address) + + observation_state = link_obs.observe(simulation.describe_state()) + assert observation_state["PROTOCOLS"]["ALL"] == 1 diff --git a/tests/integration_tests/game_layer/observations/test_nic_observations.py b/tests/integration_tests/game_layer/observations/test_nic_observations.py new file mode 100644 index 00000000..332bc1f7 --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_nic_observations.py @@ -0,0 +1,97 @@ +from pathlib import Path +from typing import Union + +import pytest +import yaml +from gymnasium import spaces + +from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.game.game import PrimaiteGame +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.network.hardware.nodes.host.host_node import NIC +from primaite.simulator.network.nmne import CAPTURE_NMNE +from primaite.simulator.sim_container import Simulation +from tests import TEST_ASSETS_ROOT + +BASIC_CONFIG = TEST_ASSETS_ROOT / "configs/basic_switched_network.yaml" + + +def load_config(config_path: Union[str, Path]) -> PrimaiteGame: + """Returns a PrimaiteGame object which loads the contents of a given yaml path.""" + with open(config_path, "r") as f: + cfg = yaml.safe_load(f) + + return PrimaiteGame.from_config(cfg) + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_nic(simulation): + """Test the NIC observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic: NIC = pc.network_interface[1] + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + assert nic_obs.space["nic_status"] == spaces.Discrete(3) + assert nic_obs.space["NMNE"]["inbound"] == spaces.Discrete(4) + assert nic_obs.space["NMNE"]["outbound"] == spaces.Discrete(4) + + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 1 # enabled + assert observation_state.get("NMNE") is not None + assert observation_state["NMNE"].get("inbound") == 0 + assert observation_state["NMNE"].get("outbound") == 0 + + nic.disable() + observation_state = nic_obs.observe(simulation.describe_state()) + assert observation_state.get("nic_status") == 2 # disabled + + +def test_nic_categories(simulation): + """Test the NIC observation nmne count categories.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + nic_obs = NicObservation(where=["network", "nodes", pc.hostname, "NICs", 1]) + + assert nic_obs.high_nmne_threshold == 10 # default + assert nic_obs.med_nmne_threshold == 5 # default + assert nic_obs.low_nmne_threshold == 0 # default + + nic_obs = NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + assert nic_obs.high_nmne_threshold == 9 + assert nic_obs.med_nmne_threshold == 6 + assert nic_obs.low_nmne_threshold == 3 + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=9, + med_nmne_threshold=6, + high_nmne_threshold=9, + ) + + with pytest.raises(Exception): + # should throw an error + NicObservation( + where=["network", "nodes", pc.hostname, "NICs", 1], + low_nmne_threshold=3, + med_nmne_threshold=9, + high_nmne_threshold=9, + ) diff --git a/tests/integration_tests/game_layer/observations/test_node_observations.py b/tests/integration_tests/game_layer/observations/test_node_observations.py new file mode 100644 index 00000000..dce05b6a --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_node_observations.py @@ -0,0 +1,46 @@ +import copy +from uuid import uuid4 + +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.node_observations import NodeObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_node_observation(simulation): + """Test a Node observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + + node_obs = NodeObservation(where=["network", "nodes", pc.hostname]) + + assert node_obs.space["operating_status"] == spaces.Discrete(5) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 1 # computer is on + + assert observation_state.get("SERVICES") is not None + assert observation_state.get("FOLDERS") is not None + assert observation_state.get("NICS") is not None + + # turn off computer + pc.power_off() + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 4 # shutting down + + for i in range(pc.shut_down_duration + 1): + pc.apply_timestep(i) + + observation_state = node_obs.observe(simulation.describe_state()) + assert observation_state.get("operating_status") == 2 diff --git a/tests/integration_tests/game_layer/observations/test_software_observations.py b/tests/integration_tests/game_layer/observations/test_software_observations.py new file mode 100644 index 00000000..4ae0701e --- /dev/null +++ b/tests/integration_tests/game_layer/observations/test_software_observations.py @@ -0,0 +1,70 @@ +import pytest +from gymnasium import spaces + +from primaite.game.agent.observations.software_observation import ApplicationObservation, ServiceObservation +from primaite.simulator.network.hardware.nodes.host.computer import Computer +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.applications.web_browser import WebBrowser +from primaite.simulator.system.services.ntp.ntp_server import NTPServer + + +@pytest.fixture(scope="function") +def simulation(example_network) -> Simulation: + sim = Simulation() + + # set simulation network as example network + sim.network = example_network + + return sim + + +def test_service_observation(simulation): + """Test the service observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(NTPServer) + + ntp_server = pc.software_manager.software.get("NTPServer") + assert ntp_server + + service_obs = ServiceObservation(where=["network", "nodes", pc.hostname, "services", "NTPServer"]) + + assert service_obs.space["operating_status"] == spaces.Discrete(7) + assert service_obs.space["health_status"] == spaces.Discrete(5) + + observation_state = service_obs.observe(simulation.describe_state()) + + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 1 # running + + ntp_server.restart() + observation_state = service_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 6 # resetting + + +def test_application_observation(simulation): + """Test the application observation.""" + pc: Computer = simulation.network.get_node_by_hostname("client_1") + # install software on the computer + pc.software_manager.install(DatabaseClient) + + web_browser: WebBrowser = pc.software_manager.software.get("WebBrowser") + assert web_browser + + app_obs = ApplicationObservation(where=["network", "nodes", pc.hostname, "applications", "WebBrowser"]) + + web_browser.close() + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 0 + assert observation_state.get("operating_status") == 2 # stopped + assert observation_state.get("num_executions") == 0 + + web_browser.run() + web_browser.scan() # scan to update health status + web_browser.get_webpage("test") + observation_state = app_obs.observe(simulation.describe_state()) + assert observation_state.get("health_status") == 1 + assert observation_state.get("operating_status") == 1 # running + assert observation_state.get("num_executions") == 1 diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 8911632c..740fb491 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -10,28 +10,14 @@ # 4. Check that the simulation has changed in the way that I expect. # 5. Repeat for all actions. -from typing import Dict, Tuple +from typing import Tuple import pytest -from primaite.game.agent.actions import ActionManager -from primaite.game.agent.interface import AbstractAgent, ProxyAgent -from primaite.game.agent.observations import ICSObservation, ObservationManager -from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame from primaite.simulator.file_system.file_system_item_abc import FileSystemItemHealthStatus -from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState -from primaite.simulator.network.hardware.nodes.host.computer import Computer -from primaite.simulator.network.hardware.nodes.host.server import Server -from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router -from primaite.simulator.network.hardware.nodes.network.switch import Switch -from primaite.simulator.network.transmission.network_layer import IPProtocol -from primaite.simulator.network.transmission.transport_layer import Port -from primaite.simulator.sim_container import Simulation from primaite.simulator.system.applications.web_browser import WebBrowser -from primaite.simulator.system.services.dns.dns_client import DNSClient -from primaite.simulator.system.services.dns.dns_server import DNSServer -from primaite.simulator.system.services.web_server.web_server import WebServer from primaite.simulator.system.software import SoftwareHealthState diff --git a/tests/integration_tests/game_layer/test_observations.py b/tests/integration_tests/game_layer/test_observations.py index d1301759..f52b52f7 100644 --- a/tests/integration_tests/game_layer/test_observations.py +++ b/tests/integration_tests/game_layer/test_observations.py @@ -1,6 +1,6 @@ from gymnasium import spaces -from primaite.game.agent.observations import FileObservation +from primaite.game.agent.observations.file_system_observations import FileObservation from primaite.simulator.network.hardware.nodes.host.computer import Computer from primaite.simulator.sim_container import Simulation diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index fd8a89a4..cfd013bc 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -1,7 +1,16 @@ -from primaite.game.agent.rewards import WebpageUnavailablePenalty +import yaml + +from primaite.game.agent.interface import AgentActionHistoryItem +from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty +from primaite.game.game import PrimaiteGame +from primaite.session.environment import PrimaiteGymEnv +from primaite.simulator.network.hardware.nodes.host.server import Server from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port +from primaite.simulator.system.applications.database_client import DatabaseClient +from primaite.simulator.system.services.database.database_service import DatabaseService +from tests import TEST_ASSETS_ROOT from tests.conftest import ControlledAgent @@ -35,3 +44,77 @@ def test_WebpageUnavailablePenalty(game_and_agent): agent.store_action(action) game.step() assert agent.reward_function.current_reward == -0.7 + + +def test_uc2_rewards(game_and_agent): + """Test that the reward component correctly applies a penalty when the selected client cannot reach the database.""" + game, agent = game_and_agent + agent: ControlledAgent + + server_1: Server = game.simulation.network.get_node_by_hostname("server_1") + server_1.software_manager.install(DatabaseService) + db_service = server_1.software_manager.software.get("DatabaseService") + db_service.start() + + client_1 = game.simulation.network.get_node_by_hostname("client_1") + client_1.software_manager.install(DatabaseClient) + db_client: DatabaseClient = client_1.software_manager.software.get("DatabaseClient") + db_client.configure(server_ip_address=server_1.network_interface[1].ip_address) + db_client.run() + + router: Router = game.simulation.network.get_node_by_hostname("router") + router.acl.add_rule(ACLAction.PERMIT, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER, position=2) + + comp = GreenAdminDatabaseUnreachablePenalty("client_1") + + response = db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate( + state, + last_action_response=AgentActionHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + ), + ) + assert reward_value == 1.0 + + router.acl.remove_rule(position=2) + + db_client.apply_request( + [ + "execute", + ] + ) + state = game.get_sim_state() + reward_value = comp.calculate( + state, + last_action_response=AgentActionHistoryItem( + timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response + ), + ) + assert reward_value == -1.0 + + +def test_shared_reward(): + CFG_PATH = TEST_ASSETS_ROOT / "configs/shared_rewards.yaml" + with open(CFG_PATH, "r") as f: + cfg = yaml.safe_load(f) + + env = PrimaiteGymEnv(game_config=cfg) + + env.reset() + + order = env.game._reward_calculation_order + assert order.index("defender") > order.index("client_1_green_user") + assert order.index("defender") > order.index("client_2_green_user") + + for step in range(256): + act = env.action_space.sample() + env.step(act) + g1_reward = env.game.agents["client_1_green_user"].reward_function.current_reward + g2_reward = env.game.agents["client_2_green_user"].reward_function.current_reward + blue_reward = env.game.agents["defender"].reward_function.current_reward + assert blue_reward == g1_reward + g2_reward diff --git a/tests/integration_tests/network/test_capture_nmne.py b/tests/integration_tests/network/test_capture_nmne.py new file mode 100644 index 00000000..9efc70f7 --- /dev/null +++ b/tests/integration_tests/network/test_capture_nmne.py @@ -0,0 +1,198 @@ +from primaite.game.agent.observations.nic_observations import NicObservation +from primaite.simulator.network.hardware.nodes.host.server import Server +from primaite.simulator.network.nmne import set_nmne_config +from primaite.simulator.sim_container import Simulation +from primaite.simulator.system.applications.database_client import DatabaseClient + + +def test_capture_nmne(uc2_network): + """ + Conducts a test to verify that Malicious Network Events (MNEs) are correctly captured. + + This test involves a web server querying a database server and checks if the MNEs are captured + based on predefined keywords in the network configuration. Specifically, it checks the capture + of the "DELETE" SQL command as a malicious network event. + """ + web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa + db_client.connect() + + db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa + + web_server_nic = web_server.network_interface[1] + db_server_nic = db_server.network_interface[1] + + # Set the NMNE configuration to capture DELETE queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + set_nmne_config(nmne_config) + + # Assert that initially, there are no captured MNEs on both web and database servers + assert web_server_nic.nmne == {} + assert db_server_nic.nmne == {} + + # Perform a "SELECT" query + db_client.query("SELECT") + + # Check that it does not trigger an MNE capture. + assert web_server_nic.nmne == {} + assert db_server_nic.nmne == {} + + # Perform a "DELETE" query + db_client.query("DELETE") + + # Check that the web server's outbound interface and the database server's inbound interface register the MNE + assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "SELECT" query + db_client.query("SELECT") + + # Check that no additional MNEs are captured + assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "DELETE" query + db_client.query("DELETE") + + # Check that the web server and database server interfaces register an additional MNE + assert web_server_nic.nmne == {"direction": {"outbound": {"keywords": {"*": 2}}}} + assert db_server_nic.nmne == {"direction": {"inbound": {"keywords": {"*": 2}}}} + + +def test_describe_state_nmne(uc2_network): + """ + Conducts a test to verify that Malicious Network Events (MNEs) are correctly represented in the nic state. + + This test involves a web server querying a database server and checks if the MNEs are captured + based on predefined keywords in the network configuration. Specifically, it checks the capture + of the "DELETE" SQL command as a malicious network event. It also checks that running describe_state + only shows MNEs since the last time describe_state was called. + """ + web_server: Server = uc2_network.get_node_by_hostname("web_server") # noqa + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] # noqa + db_client.connect() + + db_server: Server = uc2_network.get_node_by_hostname("database_server") # noqa + + web_server_nic = web_server.network_interface[1] + db_server_nic = db_server.network_interface[1] + + # Set the NMNE configuration to capture DELETE queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + set_nmne_config(nmne_config) + + # Assert that initially, there are no captured MNEs on both web and database servers + web_server_nic_state = web_server_nic.describe_state() + db_server_nic_state = db_server_nic.describe_state() + uc2_network.apply_timestep(timestep=0) + assert web_server_nic_state["nmne"] == {} + assert db_server_nic_state["nmne"] == {} + + # Perform a "SELECT" query + db_client.query("SELECT") + + # Check that it does not trigger an MNE capture. + web_server_nic_state = web_server_nic.describe_state() + db_server_nic_state = db_server_nic.describe_state() + uc2_network.apply_timestep(timestep=0) + assert web_server_nic_state["nmne"] == {} + assert db_server_nic_state["nmne"] == {} + + # Perform a "DELETE" query + db_client.query("DELETE") + + # Check that the web server's outbound interface and the database server's inbound interface register the MNE + web_server_nic_state = web_server_nic.describe_state() + db_server_nic_state = db_server_nic.describe_state() + uc2_network.apply_timestep(timestep=0) + assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "SELECT" query + db_client.query("SELECT") + + # Check that no additional MNEs are captured + web_server_nic_state = web_server_nic.describe_state() + db_server_nic_state = db_server_nic.describe_state() + uc2_network.apply_timestep(timestep=0) + assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 1}}}} + assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 1}}}} + + # Perform another "DELETE" query + db_client.query("DELETE") + + # Check that the web server and database server interfaces register an additional MNE + web_server_nic_state = web_server_nic.describe_state() + db_server_nic_state = db_server_nic.describe_state() + uc2_network.apply_timestep(timestep=0) + assert web_server_nic_state["nmne"] == {"direction": {"outbound": {"keywords": {"*": 2}}}} + assert db_server_nic_state["nmne"] == {"direction": {"inbound": {"keywords": {"*": 2}}}} + + +def test_capture_nmne_observations(uc2_network): + """ + Tests the NicObservation class's functionality within a simulated network environment. + + This test ensures the observation space, as defined by instances of NicObservation, accurately reflects the + number of MNEs detected based on network activities over multiple iterations. + + The test employs a series of "DELETE" SQL operations, considered as MNEs, to validate the dynamic update + and accuracy of the observation space related to network interface conditions. It confirms that the + observed NIC states match expected MNE activity levels. + """ + # Initialise a new Simulation instance and assign the test network to it. + sim = Simulation() + sim.network = uc2_network + + web_server: Server = uc2_network.get_node_by_hostname("web_server") + db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"] + db_client.connect() + + # Set the NMNE configuration to capture DELETE queries as MNEs + nmne_config = { + "capture_nmne": True, # Enable the capture of MNEs + "nmne_capture_keywords": ["DELETE"], # Specify "DELETE" SQL command as a keyword for MNE detection + } + + # Apply the NMNE configuration settings + set_nmne_config(nmne_config) + + # Define observations for the NICs of the database and web servers + db_server_nic_obs = NicObservation(where=["network", "nodes", "database_server", "NICs", 1]) + web_server_nic_obs = NicObservation(where=["network", "nodes", "web_server", "NICs", 1]) + + # Iterate through a set of test cases to simulate multiple DELETE queries + for i in range(0, 20): + # Perform a "DELETE" query each iteration + for j in range(i): + db_client.query("DELETE") + + # Observe the current state of NMNEs from the NICs of both the database and web servers + state = sim.describe_state() + db_nic_obs = db_server_nic_obs.observe(state)["NMNE"] + web_nic_obs = web_server_nic_obs.observe(state)["NMNE"] + + # Define expected NMNE values based on the iteration count + if i > 10: + expected_nmne = 3 # High level of detected MNEs after 10 iterations + elif i > 5: + expected_nmne = 2 # Moderate level after more than 5 iterations + elif i > 0: + expected_nmne = 1 # Low level detected after just starting + else: + expected_nmne = 0 # No MNEs detected + + # Assert that the observed NMNEs match the expected values for both NICs + assert web_nic_obs["outbound"] == expected_nmne + assert db_nic_obs["inbound"] == expected_nmne + uc2_network.apply_timestep(timestep=0) diff --git a/tests/integration_tests/test_simulation/__init__.py b/tests/integration_tests/test_simulation/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/test_simulation/test_request_response.py b/tests/integration_tests/test_simulation/test_request_response.py new file mode 100644 index 00000000..aee5c816 --- /dev/null +++ b/tests/integration_tests/test_simulation/test_request_response.py @@ -0,0 +1,160 @@ +# some test cases: +# 0. test that sending a request to a valid target results in a success +# 1. test that sending a request to a component that doesn't exist results in a failure +# 2. test that sending a request to a software on a turned-off component results in a failure +# 3. test every implemented action under several different situation, some of which should lead to a success and some to a failure. + +import pytest + +from primaite.interface.request import RequestResponse +from primaite.simulator.network.hardware.node_operating_state import NodeOperatingState +from primaite.simulator.network.hardware.nodes.host.host_node import HostNode +from primaite.simulator.network.hardware.nodes.network.router import ACLAction, Router +from primaite.simulator.network.transmission.transport_layer import Port +from tests.conftest import TestApplication, TestService + + +def test_successful_application_requests(example_network): + net = example_network + + client_1 = net.get_node_by_hostname("client_1") + client_1.software_manager.install(TestApplication) + client_1.software_manager.software.get("TestApplication").run() + + resp_1 = net.apply_request(["node", "client_1", "application", "TestApplication", "scan"]) + assert resp_1 == RequestResponse(status="success", data={}) + resp_2 = net.apply_request(["node", "client_1", "application", "TestApplication", "patch"]) + assert resp_2 == RequestResponse(status="success", data={}) + resp_3 = net.apply_request(["node", "client_1", "application", "TestApplication", "compromise"]) + assert resp_3 == RequestResponse(status="success", data={}) + + +def test_successful_service_requests(example_network): + net = example_network + server_1 = net.get_node_by_hostname("server_1") + server_1.software_manager.install(TestService) + + # Careful: the order here is important, for example we cannot run "stop" unless we run "start" first + for verb in [ + "disable", + "enable", + "start", + "stop", + "start", + "restart", + "pause", + "resume", + "compromise", + "scan", + "patch", + ]: + resp_1 = net.apply_request(["node", "server_1", "service", "TestService", verb]) + assert resp_1 == RequestResponse(status="success", data={}) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + server_1.apply_timestep(timestep=1) + # lazily apply timestep 7 times to make absolutely sure any time-based things like restart have a chance to finish + + +def test_non_existent_requests(example_network): + net = example_network + resp_1 = net.apply_request(["fake"]) + assert resp_1.status == "unreachable" + resp_2 = net.apply_request(["network", "node", "client_39", "application", "WebBrowser", "execute"]) + assert resp_2.status == "unreachable" + + +@pytest.mark.parametrize( + "node_request", + [ + ["node", "client_1", "file_system", "folder", "root", "scan"], + ["node", "client_1", "os", "scan"], + ["node", "client_1", "service", "DNSClient", "stop"], + ["node", "client_1", "application", "WebBrowser", "scan"], + ["node", "client_1", "network_interface", 1, "disable"], + ], +) +def test_request_fails_if_node_off(example_network, node_request): + """Test that requests succeed when the node is on, and fail if the node is off.""" + net = example_network + client_1: HostNode = net.get_node_by_hostname("client_1") + client_1.shut_down_duration = 0 + + assert client_1.operating_state == NodeOperatingState.ON + resp_1 = net.apply_request(node_request) + assert resp_1.status == "success" + + client_1.power_off() + assert client_1.operating_state == NodeOperatingState.OFF + resp_2 = net.apply_request(node_request) + assert resp_2.status == "failure" + + +class TestDataManipulationGreenRequests: + def test_node_off(self, uc2_network): + """Test that green requests succeed when the node is on and fail if the node is off.""" + net = uc2_network + + client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) + client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) + client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) + client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) + assert client_1_browser_execute.status == "success" + assert client_1_db_client_execute.status == "success" + assert client_2_browser_execute.status == "success" + assert client_2_db_client_execute.status == "success" + + client_1 = net.get_node_by_hostname("client_1") + client_2 = net.get_node_by_hostname("client_2") + + client_1.shut_down_duration = 0 + client_1.power_off() + client_2.shut_down_duration = 0 + client_2.power_off() + + client_1_browser_execute_off = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) + client_1_db_client_execute_off = net.apply_request( + ["node", "client_1", "application", "DatabaseClient", "execute"] + ) + client_2_browser_execute_off = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) + client_2_db_client_execute_off = net.apply_request( + ["node", "client_2", "application", "DatabaseClient", "execute"] + ) + assert client_1_browser_execute_off.status == "failure" + assert client_1_db_client_execute_off.status == "failure" + assert client_2_browser_execute_off.status == "failure" + assert client_2_db_client_execute_off.status == "failure" + + def test_acl_block(self, uc2_network): + """Test that green requests succeed when not blocked by ACLs but fail when blocked.""" + net = uc2_network + + router: Router = net.get_node_by_hostname("router_1") + client_1: HostNode = net.get_node_by_hostname("client_1") + client_2: HostNode = net.get_node_by_hostname("client_2") + + client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) + client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) + assert client_1_browser_execute.status == "success" + assert client_2_browser_execute.status == "success" + + router.acl.add_rule(ACLAction.DENY, src_port=Port.HTTP, dst_port=Port.HTTP, position=3) + client_1_browser_execute = net.apply_request(["node", "client_1", "application", "WebBrowser", "execute"]) + client_2_browser_execute = net.apply_request(["node", "client_2", "application", "WebBrowser", "execute"]) + assert client_1_browser_execute.status == "failure" + assert client_2_browser_execute.status == "failure" + + client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) + client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) + assert client_1_db_client_execute.status == "success" + assert client_2_db_client_execute.status == "success" + + router.acl.add_rule(ACLAction.DENY, src_port=Port.POSTGRES_SERVER, dst_port=Port.POSTGRES_SERVER) + client_1_db_client_execute = net.apply_request(["node", "client_1", "application", "DatabaseClient", "execute"]) + client_2_db_client_execute = net.apply_request(["node", "client_2", "application", "DatabaseClient", "execute"]) + assert client_1_db_client_execute.status == "failure" + assert client_2_db_client_execute.status == "failure" diff --git a/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py new file mode 100644 index 00000000..c556cfad --- /dev/null +++ b/tests/unit_tests/_primaite/_game/_agent/test_probabilistic_agent.py @@ -0,0 +1,85 @@ +from primaite.game.agent.actions import ActionManager +from primaite.game.agent.observations.observation_manager import ObservationManager +from primaite.game.agent.observations.observations import ICSObservation +from primaite.game.agent.rewards import RewardFunction +from primaite.game.agent.scripted_agents.probabilistic_agent import ProbabilisticAgent + + +def test_probabilistic_agent(): + """ + Check that the probabilistic agent selects actions with approximately the right probabilities. + + Using a binomial probability calculator (https://www.wolframalpha.com/input/?i=binomial+distribution+calculator), + we can generate some lower and upper bounds of how many times we expect the agent to take each action. These values + were chosen to guarantee a less than 1 in a million chance of the test failing due to unlucky random number + generation. + """ + N_TRIALS = 10_000 + P_DO_NOTHING = 0.1 + P_NODE_APPLICATION_EXECUTE = 0.3 + P_NODE_FILE_DELETE = 0.6 + MIN_DO_NOTHING = 850 + MAX_DO_NOTHING = 1150 + MIN_NODE_APPLICATION_EXECUTE = 2800 + MAX_NODE_APPLICATION_EXECUTE = 3200 + MIN_NODE_FILE_DELETE = 5750 + MAX_NODE_FILE_DELETE = 6250 + + action_space = ActionManager( + actions=[ + {"type": "DONOTHING"}, + {"type": "NODE_APPLICATION_EXECUTE"}, + {"type": "NODE_FILE_DELETE"}, + ], + nodes=[ + { + "node_name": "client_1", + "applications": [{"application_name": "WebBrowser"}], + "folders": [{"folder_name": "downloads", "files": [{"file_name": "cat.png"}]}], + }, + ], + max_folders_per_node=2, + max_files_per_folder=2, + max_services_per_node=2, + max_applications_per_node=2, + max_nics_per_node=2, + max_acl_rules=10, + protocols=["TCP", "UDP", "ICMP"], + ports=["HTTP", "DNS", "ARP"], + act_map={ + 0: {"action": "DONOTHING", "options": {}}, + 1: {"action": "NODE_APPLICATION_EXECUTE", "options": {"node_id": 0, "application_id": 0}}, + 2: {"action": "NODE_FILE_DELETE", "options": {"node_id": 0, "folder_id": 0, "file_id": 0}}, + }, + ) + observation_space = ObservationManager(ICSObservation()) + reward_function = RewardFunction() + + pa = ProbabilisticAgent( + agent_name="test_agent", + action_space=action_space, + observation_space=observation_space, + reward_function=reward_function, + settings={ + "action_probabilities": {0: P_DO_NOTHING, 1: P_NODE_APPLICATION_EXECUTE, 2: P_NODE_FILE_DELETE}, + "random_seed": 120, + }, + ) + + do_nothing_count = 0 + node_application_execute_count = 0 + node_file_delete_count = 0 + for _ in range(N_TRIALS): + a = pa.get_action(0) + if a == ("DONOTHING", {}): + do_nothing_count += 1 + elif a == ("NODE_APPLICATION_EXECUTE", {"node_id": 0, "application_id": 0}): + node_application_execute_count += 1 + elif a == ("NODE_FILE_DELETE", {"node_id": 0, "folder_id": 0, "file_id": 0}): + node_file_delete_count += 1 + else: + raise AssertionError("Probabilistic agent produced an unexpected action.") + + assert MIN_DO_NOTHING < do_nothing_count < MAX_DO_NOTHING + assert MIN_NODE_APPLICATION_EXECUTE < node_application_execute_count < MAX_NODE_APPLICATION_EXECUTE + assert MIN_NODE_FILE_DELETE < node_file_delete_count < MAX_NODE_FILE_DELETE diff --git a/tests/unit_tests/_primaite/_interface/__init__.py b/tests/unit_tests/_primaite/_interface/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_interface/test_request.py b/tests/unit_tests/_primaite/_interface/test_request.py new file mode 100644 index 00000000..5c65b572 --- /dev/null +++ b/tests/unit_tests/_primaite/_interface/test_request.py @@ -0,0 +1,32 @@ +import pytest +from pydantic import ValidationError + +from primaite.interface.request import RequestResponse + + +def test_creating_response_object(): + """Test that we can create a response object with given parameters.""" + r1 = RequestResponse(status="success", data={"test_data": 1, "other_data": 2}) + r2 = RequestResponse(status="unreachable") + r3 = RequestResponse(data={"test_data": "is_good"}) + r4 = RequestResponse() + assert isinstance(r1, RequestResponse) + assert isinstance(r2, RequestResponse) + assert isinstance(r3, RequestResponse) + assert isinstance(r4, RequestResponse) + + +def test_creating_response_from_boolean(): + """Test that we can build a response with a single boolean.""" + r1 = RequestResponse.from_bool(True) + assert r1.status == "success" + + r2 = RequestResponse.from_bool(False) + assert r2.status == "failure" + + with pytest.raises(ValidationError): + r3 = RequestResponse.from_bool(1) + with pytest.raises(ValidationError): + r4 = RequestResponse.from_bool("good") + with pytest.raises(ValidationError): + r5 = RequestResponse.from_bool({"data": True}) diff --git a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py index 01ad3871..786fe851 100644 --- a/tests/unit_tests/_primaite/_simulator/_domain/test_account.py +++ b/tests/unit_tests/_primaite/_simulator/_domain/test_account.py @@ -7,26 +7,11 @@ from primaite.simulator.domain.account import Account, AccountType @pytest.fixture(scope="function") def account() -> Account: acct = Account(username="Jake", password="totally_hashed_password", account_type=AccountType.USER) - acct.set_original_state() return acct def test_original_state(account): """Test the original state - see if it resets properly""" - account.log_on() - account.log_off() - account.disable() - - state = account.describe_state() - assert state["num_logons"] is 1 - assert state["num_logoffs"] is 1 - assert state["num_group_changes"] is 0 - assert state["username"] is "Jake" - assert state["password"] is "totally_hashed_password" - assert state["account_type"] is AccountType.USER.value - assert state["enabled"] is False - - account.reset_component_for_episode(episode=1) state = account.describe_state() assert state["num_logons"] is 0 assert state["num_logoffs"] is 0 @@ -39,13 +24,7 @@ def test_original_state(account): account.log_on() account.log_off() account.disable() - account.set_original_state() - account.log_on() - state = account.describe_state() - assert state["num_logons"] is 2 - - account.reset_component_for_episode(episode=2) state = account.describe_state() assert state["num_logons"] is 1 assert state["num_logoffs"] is 1 diff --git a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py index 9366d173..05824834 100644 --- a/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py +++ b/tests/unit_tests/_primaite/_simulator/_file_system/test_file_system.py @@ -1,7 +1,9 @@ import pytest +from primaite.simulator.file_system.file import File from primaite.simulator.file_system.file_system import FileSystem from primaite.simulator.file_system.file_type import FileType +from primaite.simulator.file_system.folder import Folder def test_create_folder_and_file(file_system): @@ -14,8 +16,15 @@ def test_create_folder_and_file(file_system): assert len(file_system.get_folder("test_folder").files) == 1 + assert file_system.num_file_creations == 1 + assert file_system.get_folder("test_folder").get_file("test_file.txt") + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) @@ -23,24 +32,37 @@ def test_create_file_no_folder(file_system): """Tests that creating a file without a folder creates a folder and sets that as the file's parent.""" file = file_system.create_file(file_name="test_file.txt", size=10) assert len(file_system.folders) is 1 + assert file_system.num_file_creations == 1 assert file_system.get_folder("root").get_file("test_file.txt") == file assert file_system.get_folder("root").get_file("test_file.txt").file_type == FileType.TXT assert file_system.get_folder("root").get_file("test_file.txt").size == 10 + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) def test_delete_file(file_system): """Tests that a file can be deleted.""" - file_system.create_file(file_name="test_file.txt") + file = file_system.create_file(file_name="test_file.txt") assert len(file_system.folders) == 1 assert len(file_system.get_folder("root").files) == 1 file_system.delete_file(folder_name="root", file_name="test_file.txt") + assert file.num_access == 1 + assert file_system.num_file_deletions == 1 assert len(file_system.folders) == 1 assert len(file_system.get_folder("root").files) == 0 assert len(file_system.get_folder("root").deleted_files) == 1 + file_system.apply_timestep(0) + + # num file deletions should reset + assert file_system.num_file_deletions == 0 + file_system.show(full=True) @@ -54,6 +76,7 @@ def test_delete_non_existent_file(file_system): # deleting should not change how many files are in folder file_system.delete_file(folder_name="root", file_name="does_not_exist!") + assert file_system.num_file_deletions == 0 # should still only be one folder assert len(file_system.folders) == 1 @@ -96,6 +119,7 @@ def test_create_duplicate_file(file_system): assert len(file_system.folders) is 2 file_system.create_file(file_name="test_file.txt", folder_name="test_folder") + assert file_system.num_file_creations == 1 assert len(file_system.get_folder("test_folder").files) == 1 @@ -103,6 +127,7 @@ def test_create_duplicate_file(file_system): file_system.create_file(file_name="test_file.txt", folder_name="test_folder") assert len(file_system.get_folder("test_folder").files) == 1 + assert file_system.num_file_creations == 1 file_system.show(full=True) @@ -136,13 +161,24 @@ def test_move_file(file_system): assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 0 + assert file_system.num_file_deletions == 0 + assert file_system.num_file_creations == 1 file_system.move_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") + assert file_system.num_file_deletions == 1 + assert file_system.num_file_creations == 2 + assert file.num_access == 1 assert len(file_system.get_folder("src_folder").files) == 0 assert len(file_system.get_folder("dst_folder").files) == 1 assert file_system.get_file("dst_folder", "test_file.txt").uuid == original_uuid + file_system.apply_timestep(0) + + # num file creations and deletions should reset + assert file_system.num_file_creations == 0 + assert file_system.num_file_deletions == 0 + file_system.show(full=True) @@ -152,17 +188,25 @@ def test_copy_file(file_system): file_system.create_folder(folder_name="dst_folder") file = file_system.create_file(file_name="test_file.txt", size=10, folder_name="src_folder", real=True) + assert file_system.num_file_creations == 1 original_uuid = file.uuid assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 0 file_system.copy_file(src_folder_name="src_folder", src_file_name="test_file.txt", dst_folder_name="dst_folder") + assert file_system.num_file_creations == 2 + assert file.num_access == 1 assert len(file_system.get_folder("src_folder").files) == 1 assert len(file_system.get_folder("dst_folder").files) == 1 assert file_system.get_file("dst_folder", "test_file.txt").uuid != original_uuid + file_system.apply_timestep(0) + + # num file creations should reset + assert file_system.num_file_creations == 0 + file_system.show(full=True) @@ -172,51 +216,23 @@ def test_get_file(file_system): file1: File = file_system.create_file(file_name="test_file.txt", folder_name="test_folder") file2: File = file_system.create_file(file_name="test_file2.txt", folder_name="test_folder") - folder.remove_file(file2) + file_system.delete_file("test_folder", "test_file2.txt") + # file 2 was accessed before being deleted + assert file2.num_access == 1 assert file_system.get_file_by_id(file_uuid=file1.uuid, folder_uuid=folder.uuid) is not None assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid) is None assert file_system.get_file_by_id(file_uuid=file2.uuid, folder_uuid=folder.uuid, include_deleted=True) is not None assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None + assert file2.num_access == 1 # cannot access deleted file + file_system.delete_folder(folder_name="test_folder") assert file_system.get_file_by_id(file_uuid=file2.uuid, include_deleted=True) is not None file_system.show(full=True) -def test_reset_file_system(file_system): - # file and folder that existed originally - file_system.create_file(file_name="test_file.zip") - file_system.create_folder(folder_name="test_folder") - file_system.set_original_state() - - # create a new file - file_system.create_file(file_name="new_file.txt") - - # create a new folder - file_system.create_folder(folder_name="new_folder") - - # delete the file that existed originally - file_system.delete_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_file(folder_name="root", file_name="test_file.zip") is None - - # delete the folder that existed originally - file_system.delete_folder(folder_name="test_folder") - assert file_system.get_folder(folder_name="test_folder") is None - - # reset - file_system.reset_component_for_episode(episode=1) - - # deleted original file and folder should be back - assert file_system.get_file(folder_name="root", file_name="test_file.zip") - assert file_system.get_folder(folder_name="test_folder") - - # new file and folder should be removed - assert file_system.get_file(folder_name="root", file_name="new_file.txt") is None - assert file_system.get_folder(folder_name="new_folder") is None - - @pytest.mark.skip(reason="Skipping until we tackle serialisation") def test_serialisation(file_system): """Test to check that the object serialisation works correctly.""" diff --git a/tests/unit_tests/_primaite/_simulator/_network/test_container.py b/tests/unit_tests/_primaite/_simulator/_network/test_container.py index 9d424697..f0e386b8 100644 --- a/tests/unit_tests/_primaite/_simulator/_network/test_container.py +++ b/tests/unit_tests/_primaite/_simulator/_network/test_container.py @@ -26,12 +26,11 @@ def filter_keys_nested_item(data, keys): @pytest.fixture(scope="function") def network(example_network) -> Network: - assert len(example_network.routers) is 1 - assert len(example_network.switches) is 2 - assert len(example_network.computers) is 2 - assert len(example_network.servers) is 2 + assert len(example_network.router_nodes) is 1 + assert len(example_network.switch_nodes) is 2 + assert len(example_network.computer_nodes) is 2 + assert len(example_network.server_nodes) is 2 - example_network.set_original_state() example_network.show() return example_network @@ -45,40 +44,6 @@ def test_describe_state(network): assert len(state["links"]) is 6 -def test_reset_network(network): - """ - Test that the network is properly reset. - - TODO: make sure that once implemented - any installed/uninstalled services, processes, apps, - etc are also removed/reinstalled - - """ - state_before = network.describe_state() - - client_1: Computer = network.get_node_by_hostname("client_1") - server_1: Computer = network.get_node_by_hostname("server_1") - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - - client_1.power_off() - assert client_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - server_1.power_off() - assert server_1.operating_state is NodeOperatingState.SHUTTING_DOWN - - assert network.describe_state() != state_before - - network.reset_component_for_episode(episode=1) - - assert client_1.operating_state is NodeOperatingState.ON - assert server_1.operating_state is NodeOperatingState.ON - # don't worry if UUIDs change - a = filter_keys_nested_item(json.dumps(network.describe_state(), sort_keys=True, indent=2), ["uuid"]) - b = filter_keys_nested_item(json.dumps(state_before, sort_keys=True, indent=2), ["uuid"]) - assert a == b - - def test_creating_container(): """Check that we can create a network container""" net = Network() diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py index 2ca67119..6d00886a 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_data_manipulation_bot.py @@ -26,8 +26,8 @@ def test_create_dm_bot(dm_client): data_manipulation_bot: DataManipulationBot = dm_client.software_manager.software.get("DataManipulationBot") assert data_manipulation_bot.name == "DataManipulationBot" - assert data_manipulation_bot.port == Port.POSTGRES_SERVER - assert data_manipulation_bot.protocol == IPProtocol.TCP + assert data_manipulation_bot.port == Port.NONE + assert data_manipulation_bot.protocol == IPProtocol.NONE assert data_manipulation_bot.payload == "DELETE" @@ -70,4 +70,13 @@ def test_dm_bot_perform_data_manipulation_success(dm_bot): dm_bot._perform_data_manipulation(p_of_success=1.0) assert dm_bot.attack_stage in (DataManipulationAttackStage.SUCCEEDED, DataManipulationAttackStage.FAILED) - assert len(dm_bot.connections) + assert len(dm_bot._host_db_client.connections) + + +def test_dm_bot_fails_without_db_client(dm_client): + dm_client.software_manager.uninstall("DatabaseClient") + dm_bot = dm_client.software_manager.software.get("DataManipulationBot") + assert dm_bot._host_db_client is None + dm_bot.attack_stage = DataManipulationAttackStage.PORT_SCAN + dm_bot._perform_data_manipulation(p_of_success=1.0) + assert dm_bot.attack_stage is DataManipulationAttackStage.FAILED diff --git a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py index ccf40c44..4bfd28d0 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py +++ b/tests/unit_tests/_primaite/_simulator/_system/_applications/_red_applications/test_dos_bot.py @@ -19,7 +19,6 @@ def dos_bot() -> DoSBot: dos_bot: DoSBot = computer.software_manager.software.get("DoSBot") dos_bot.configure(target_ip_address=IPv4Address("192.168.0.1")) - dos_bot.set_original_state() return dos_bot @@ -28,35 +27,6 @@ def test_dos_bot_creation(dos_bot): assert dos_bot is not None -def test_dos_bot_reset(dos_bot): - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - - # should reset the relevant items - dos_bot.reset_component_for_episode(episode=0) - assert dos_bot.target_ip_address == IPv4Address("192.168.0.1") - assert dos_bot.target_port is Port.POSTGRES_SERVER - assert dos_bot.payload is None - assert dos_bot.repeat is False - - dos_bot.configure( - target_ip_address=IPv4Address("192.168.1.1"), target_port=Port.HTTP, payload="payload", repeat=True - ) - dos_bot.set_original_state() - dos_bot.reset_component_for_episode(episode=1) - # should reset to the configured value - assert dos_bot.target_ip_address == IPv4Address("192.168.1.1") - assert dos_bot.target_port is Port.HTTP - assert dos_bot.payload == "payload" - assert dos_bot.repeat is True - - def test_dos_bot_cannot_run_when_node_offline(dos_bot): dos_bot_node: Computer = dos_bot.parent assert dos_bot_node.operating_state is NodeOperatingState.ON diff --git a/tests/unit_tests/_primaite/_simulator/_system/test_software.py b/tests/unit_tests/_primaite/_simulator/_system/test_software.py index e77cd895..6f680012 100644 --- a/tests/unit_tests/_primaite/_simulator/_system/test_software.py +++ b/tests/unit_tests/_primaite/_simulator/_system/test_software.py @@ -2,12 +2,14 @@ from typing import Dict import pytest +from primaite.simulator.network.transmission.network_layer import IPProtocol from primaite.simulator.network.transmission.transport_layer import Port from primaite.simulator.system.core.sys_log import SysLog -from primaite.simulator.system.software import Software, SoftwareHealthState +from primaite.simulator.system.services.service import Service +from primaite.simulator.system.software import IOSoftware, SoftwareHealthState -class TestSoftware(Software): +class TestSoftware(Service): def describe_state(self) -> Dict: pass @@ -15,7 +17,11 @@ class TestSoftware(Software): @pytest.fixture(scope="function") def software(file_system): return TestSoftware( - name="TestSoftware", port=Port.ARP, file_system=file_system, sys_log=SysLog(hostname="test_service") + name="TestSoftware", + port=Port.ARP, + file_system=file_system, + sys_log=SysLog(hostname="test_service"), + protocol=IPProtocol.TCP, )