Merge remote-tracking branch 'origin/dev' into feature/2628-update-benchmarking-script-branched

This commit is contained in:
Czar Echavez
2024-06-01 00:20:21 +01:00
35 changed files with 427 additions and 348 deletions

View File

@@ -26,7 +26,7 @@ jobs:
displayName: 'Install build dependencies'
- script: |
pip install -e .[dev]
pip install -e .[dev,rl]
displayName: 'Install PrimAITE for docs autosummary'
- script: |

View File

@@ -82,12 +82,12 @@ stages:
- script: |
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
python -m pip install $PRIMAITE_WHEEL[dev]
python -m pip install $PRIMAITE_WHEEL[dev,rl]
displayName: 'Install PrimAITE'
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
- script: |
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
forfiles /p dist\ /m *.whl /c "cmd /c python -m pip install @file[dev,rl]"
displayName: 'Install PrimAITE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )

View File

@@ -49,7 +49,7 @@ jobs:
- name: Install PrimAITE for docs autosummary
run: |
set -x
python -m pip install -e .[dev]
python -m pip install -e .[dev,rl]
- name: Run build script for Sphinx pages
env:

View File

@@ -48,7 +48,7 @@ jobs:
- name: Install PrimAITE
run: |
PRIMAITE_WHEEL=$(ls ./dist/primaite*.whl)
python -m pip install $PRIMAITE_WHEEL[dev]
python -m pip install $PRIMAITE_WHEEL[dev,rl]
- name: Perform PrimAITE Setup
run: |

View File

@@ -43,7 +43,7 @@ cd ~\primaite
python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
.\.venv\Scripts\activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install primaite-3.0.0-py3-none-any.whl[rl]
primaite setup
```
@@ -66,7 +66,7 @@ mkdir ~/primaite
cd ~/primaite
python3 -m venv .venv
source .venv/bin/activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install primaite-3.0.0-py3-none-any.whl[rl]
primaite setup
```
@@ -105,7 +105,7 @@ source venv/bin/activate
#### 5. Install `primaite` with the dev extra into the venv along with all of it's dependencies
```bash
python3 -m pip install -e .[dev]
python3 -m pip install -e .[dev,rl]
```
#### 6. Perform the PrimAITE setup:
@@ -114,6 +114,9 @@ python3 -m pip install -e .[dev]
primaite setup
```
#### Note
*It is possible to install PrimAITE without Ray RLLib, StableBaselines3, or any deep learning libraries by omitting the `rl` flag in the pip install command.*
### Running PrimAITE
Use the provided jupyter notebooks as a starting point to try running PrimAITE. They are automatically copied to your PrimAITE notebook folder when you run `primaite setup`.

View File

@@ -29,6 +29,5 @@ clean:
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile | clean
pip-licenses --format=rst --with-urls --output-file=source/primaite-dependencies.rst
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 35 KiB

After

Width:  |  Height:  |  Size: 23 KiB

BIN
docs/_static/primAITE_architecture.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 106 KiB

View File

@@ -19,4 +19,3 @@
:recursive:
primaite
tests

View File

@@ -11,66 +11,79 @@ What is PrimAITE?
Overview
^^^^^^^^
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for the purposes of training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment, which includes:
The ARCD Primary-level AI Training Environment (**PrimAITE**) provides an effective simulation capability for training and evaluating AI in a cyber-defensive role. It incorporates the functionality required of a primary-level ARCD environment:
- The ability to model a relevant platform / system context;
- The ability to model a relevant system context;
- Modelling an adversarial agent that the defensive agent can be trained and evaluated against;
- The ability to model key characteristics of a platform / system by representing connections, IP addresses, ports, operating systems, services and traffic loading on links;
- Modelling background pattern-of-life;
- Operates at machine-speed to enable fast training cycles.
- The ability to model key characteristics of a system by representing hosts, servers, network devices, IP addresses, ports, operating systems, folders / files, applications, services and links;
- Modelling background (green) pattern-of-life;
- Operates at machine-speed to enable fast training cycles via Reinforcement Learning (RL).
Features
^^^^^^^^
PrimAITE incorporates the following features:
- Highly configurable (via YAML files) to provide the means to model a variety of platform / system laydowns and adversarial attack scenarios;
- A Reinforcement Learning (RL) reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure success;
- Provision of logging to support AI performance / effectiveness assessment;
- Uses the concept of Information Exchange Requirements (IERs) to model background pattern of life and adversarial behaviour;
- An Access Control List (ACL) function, mimicking the behaviour of a network firewall, is applied across the model, following standard ACL rule format (e.g. DENY/ALLOW, source IP address, destination IP address, protocol and port);
- Application of traffic to the links of the platform / system laydown adheres to the ACL ruleset;
- Presents both a Gymnasium and Ray RLLib interface to the environment, allowing integration with any compliant defensive agents;
- Allows for the saving and loading of trained defensive agents;
- Stochastic adversarial agent behaviour;
- Full capture of discrete logs relating to agent training or evaluation (system state, agent actions taken, instantaneous and average reward for every step of every episode);
- Distinct control over running a training and / or evaluation session;
- NetworkX provides laydown visualisation capability.
- Architected with a separate Simulation layer and Game layer. This separation of concerns defines a clear path towards transfer learning with environments of differing fidelity;
- Ability to reconfigure an RL reward function based on (a) the ability to counter the modelled adversarial cyber-attack, and (b) the ability to ensure success for green agents;
- Access Control List (ACL) functions for network devices (routers and firewalls), following standard ACL rule format (e.g., DENY / ALLOW, source / destination IP addresses, protocol and port);
- Application of traffic to the links of the system laydown adheres to the ACL rulesets and routing tables contained within each network device;
- Provides RL environments adherent to the Farama Foundation Gymnasium (Previously OpenAI Gym) API, allowing integration with any compliant RL Agent frameworks;
- Provides RL environments adherent to Ray RLlib environment specifications for single-agent and multi-agent scenarios;
- Assessed for compatibility with Stable-Baselines3 (SB3), Ray RLlib, and bespoke agents;
- Persona-based adversarial (Red) agent behaviour; several out-the-box personas are provided, and more can be developed to suit the needs of the task. Stochastic variations in Red agent behaviour are also included as required;
- A robust system logging tool, automatically enabled at the node level and featuring various log levels and terminal output options, enables PrimAITE users to conduct in-depth network simulations;
- A PCAP service is seamlessly integrated within the simulation, automatically capturing and logging frames for both
inbound and outbound traffic at the network interface level. This automatic functionality, combined with the ability
to separate traffic directions, significantly enhances network analysis and troubleshooting capabilities;
- Agent action logs provide a description of every action taken by each agent during the episode. This includes timestep, action, parameters, request and response, for all Blue agent activity, which is aligned with the Track 2 Common Action / Observation Space (CAOS) format. Action logs also details of all scripted / stochastic red / green agent actions;
- Environment ground truth is provided at every timestep, providing a full description of the environments true state;
- Alignment with CAOS provides the ability to transfer agents between CAOS compliant environments.
Architecture
^^^^^^^^^^^^
PrimAITE is a Python application and is therefore Operating System agnostic. The Gymnasium and Ray RLLib frameworks are employed to provide an interface and source for AI agents. Configuration of PrimAITE is achieved via included YAML files which support full control over the platform / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count. NetworkX based nodes and links host Python classes to present attributes and methods, and hence a more representative platform / system can be modelled within the simulation.
PrimAITE is a Python application and will operate on multiple Operating Systems (Windows, Linux and Mac);
a comprehensive installation and user guide is provided with each release to support its usage.
Configuration of PrimAITE is achieved via included YAML files which support full control over the network / system laydown being modelled, background pattern of life, adversarial (red agent) behaviour, and step and episode count.
A Simulation Controller layer manages the overall running of the simulation, keeping track of all low-level objects.
It is agnostic to the number of agents, their action / observation spaces, and the RL library being used.
It presents a public API providing a method for describing the current state of the simulation, a method that accepts action requests and provides responses, and a method that triggers a timestep advancement.
The Game Layer converts the simulation into a playable game for the agent(s).
it translates between simulation state and Gymnasium.Spaces to pass action / observation data between the agent(s) and the simulation. It is responsible for calculating rewards, managing Multi-Agent RL (MARL) action turns, and via a single agent interface can interact with Blue, Red and Green agents.
Agents can either generate their own scripted behaviour or accept input behaviour from an RL agent.
Finally, a Gymnasium / Ray RLlib Environment Layer forwards requests to the Game Layer as the agent sends them. This layer also manages most of the I/O, such as reading in the configuration files and saving agent logs.
.. image:: ../../_static/primAITE_architecture.png
:width: 500
:align: center
Training & Evaluation Capability
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its Gymnasium and RLLib compliant interface. Scenarios can be constructed to reflect platform / system laydowns consisting of any configuration of nodes (e.g. PCs, servers, switches etc.) and network links between them. All nodes can be configured to model services (and their status) and the traffic loading between them over the network links. Traffic loading is broken down into a per service granularity, relating directly to a protocol (e.g. Service A would be configured as a TCP service, and TCP traffic then flows between instances of Service A under the direction of a tailored IER). Highlights of PrimAITEs training and evaluation capability are:
PrimAITE provides a training and evaluation capability to AI agents in the context of cyber-attack, via its Gymnasium / Ray RLlib compliant interface.
Scenarios can be constructed to reflect network / system laydowns consisting of any configuration of nodes (e.g., PCs, servers etc.) and the networking equipment and links between them.
All nodes can be configured to contain applications, services, folders and files (and their status).
Traffic flows between services and applications as directed by an execution definition, with the traffic flow on the network governed by the network equipment (switches, routers and firewalls) and the ACL rules and routing tables they employ.
Highlights of PrimAITEs training and evaluation capability are:
- The scenario is not bound to a representation of any platform, system, or technology;
- Fully configurable (network / system laydown, IERs, node pattern-of-life, ACL, number of episodes, steps per episode) and repeatable to suit the requirements of AI agents;
- Can integrate with any Gymnasium or RLLib compliant AI agent.
Use of PrimAITE default scenarios within ARCD is supported by a “Use Case Profile” tailored to the scenario.
AI Assessment Capability
^^^^^^^^^^^^^^^^^^^^^^^^
PrimAITE includes the capability to support in-depth assessment of cyber defence AI by outputting logs of the environment state and AI behaviour throughout both training and evaluation sessions. These logs include the following data:
- Timestamp;
- Episode and step number;
- Agent identifier;
- Observation space;
- Action taken (by defensive AI);
- Reward value.
Logs are available in CSV format and provide coverage of the above data for every step of every episode.
- Fully configurable (network / system laydown, green pattern-of-life, red personas, reward function, ACL rules for each device, number of episodes / steps, action / observation space) and repeatable to suit the requirements of AI agents;
- Can integrate with any Gymnasium / Ray RLlib compliant AI agent .
PrimAITE provides a number of use cases (network and red/green action configurations) by default which the user is able to extend and modify as required.
What is PrimAITE built with
---------------------------
@@ -109,6 +122,7 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/config
source/environment
source/customising_scenarios
source/varying_config_files
.. toctree::
:caption: Notebooks:
@@ -126,13 +140,3 @@ Head over to the :ref:`getting-started` page to install and setup PrimAITE!
source/request_system
PrimAITE API <source/_autosummary/primaite>
PrimAITE Tests <source/_autosummary/tests>
.. toctree::
:caption: Project Links:
:hidden:
Code <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE>
Issues <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/issues>
Pull Requests <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/pulls>
Discussions <https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/discussions>

View File

@@ -36,11 +36,6 @@ IF EXIST %AUTOSUMMARYDIR% (
RMDIR %AUTOSUMMARYDIR% /s /q
)
REM print the YT licenses
set LICENSEBUILD=pip-licenses --format=rst --with-urls
set DEPS="%cd%\source\primaite-dependencies.rst"
%LICENSEBUILD% --output-file=%DEPS%
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
goto end

View File

@@ -82,7 +82,7 @@ Install PrimAITE
.. code-block:: bash
:caption: Unix
pip install path/to/your/primaite.whl
pip install path/to/your/primaite.whl[rl]
.. code-block:: powershell
:caption: Windows (Powershell)
@@ -107,7 +107,9 @@ Clone & Install PrimAITE for Development
To be able to extend PrimAITE further, or to build wheels manually before install, clone the repository to a location
of your choice:
1. Clone the repository
1. Clone the repository.
For example:
.. code-block:: bash
@@ -133,12 +135,12 @@ of your choice:
.. code-block:: bash
:caption: Unix
pip install -e .[dev]
pip install -e .[dev,rl]
.. code-block:: powershell
:caption: Windows (Powershell)
pip install -e .[dev]
pip install -e .[dev,rl]
To view the complete list of packages installed during PrimAITE installation, go to the dependencies page (:ref:`Dependencies`).

View File

@@ -38,14 +38,11 @@ Glossary
Blue Agent
A defensive agent that protects the network from Red Agent attacks to minimise disruption to green agents and protect data.
Information Exchange Requirement (IER)
Simulates network traffic by sending data from one network node to another via links for a specified amount of time. IERs can be part of green agent behaviour or red agent behaviour. PrimAITE can be configured to apply a penalty for green agents' IERs being blocked and a reward for red agents' IERs being blocked.
Pattern-of-Life (PoL)
PoLs allow agents to change the current hardware, OS, file system, or service statuses of nodes during the course of an episode. For example, a green agent may restart a server node to represent scheduled maintainance. A red agent's Pattern-of-Life can be used to attack nodes by changing their states to CORRUPTED or COMPROMISED.
Reward
The reward is a single number used by the blue agent to understand whether it's performing well or poorly. RL agents change their behaviour in an attempt to increase the expected reward each episode. The reward is generated based on the current states of the environment / :term:`reference environment` and is impacted positively by things like green IERS running successfully and negatively by things like nodes being compromised.
The reward is a single number used by the blue agent to understand whether it's performing well or poorly. RL agents change their behaviour in an attempt to increase the expected reward each episode. The reward is generated based on the current states of the environment and is impacted positively by things like green PoL running successfully and negatively by things like nodes being compromised.
Observation
An observation is a representation of the current state of the environment that is given to the learning agent so it can decide on which action to perform. If the environment is 'fully observable', the observation contains information about every possible aspect of the environment. More commonly, the environment is 'partially observable' which means the learning agent has to make decisions without knowing every detail of the current environment state.
@@ -65,12 +62,6 @@ Glossary
Episode
When an episode starts, the network simulation is reset to an initial state. The agents take actions on each step of the episode until it reaches a terminal state, which usually happens after a predetermined number of steps. After the terminal state is reached, a new episode starts and the RL agent has another opportunity to protect the network.
Reference environment
While the network simulation is unfolding, a parallel simulation takes place which is identical to the main one except that blue and red agent actions are not applied. This reference environment essentially shows what would be happening to the network if there had been no cyberattack or defense. The reference environment is used to calculate rewards.
Transaction
PrimAITE records the decisions of the learning agent by saving its observation, action, and reward at every time step. During each session, this data is saved to disk to allow for full inspection.
Laydown
The laydown is a file which defines the training scenario. It contains the network topology, firewall rules, services, protocols, and details about green and red agent behaviours.
@@ -78,4 +69,4 @@ Glossary
PrimAITE uses the Gymnasium reinforcement learning framework API to create a training environment and interface with RL agents. Gymnasium defines a common way of creating observations, actions, and rewards.
User app home
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>` on linux/darwin and `C:\\Users\\<username>\\primaite\\<version>` on Windows.
PrimAITE supports upgrading software version while retaining user data. The user data directory is where configs, notebooks, and results are stored, this location is `~/primaite<version>/` on linux/darwin and `C:\\Users\\<username>\\primaite<version>` on Windows.

View File

@@ -0,0 +1,37 @@
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| Name | Version | License | Description | URL |
+===================+=========+====================================+=======================================================================================================+==============================================+
| gymnasium | 0.28.1 | MIT License | A standard API for reinforcement learning and a diverse set of reference environments (formerly Gym). | https://farama.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| ipywidgets | 8.1.3 | BSD License | Jupyter interactive widgets | http://jupyter.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| jupyterlab | 3.6.1 | BSD License | JupyterLab computational environment | https://jupyter.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| kaleido | 0.2.1 | MIT | Static image export for web-based visualization libraries with zero dependencies | https://github.com/plotly/Kaleido |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| matplotlib | 3.7.1 | Python Software Foundation License | Python plotting package | https://matplotlib.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| networkx | 3.1 | BSD License | Python package for creating and manipulating graphs and networks | https://networkx.org/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| numpy | 1.23.5 | BSD License | NumPy is the fundamental package for array computing with Python. | https://www.numpy.org |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| platformdirs | 3.5.1 | MIT License | A small Python package for determining appropriate platform-specific dirs, e.g. a "user data dir". | https://github.com/platformdirs/platformdirs |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| plotly | 5.15.0 | MIT License | An open-source, interactive data visualization library for Python | https://plotly.com/python/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| polars | 0.18.4 | MIT License | Blazingly fast DataFrame library | https://www.pola.rs/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| prettytable | 3.8.0 | BSD License (BSD (3 clause)) | A simple Python library for easily displaying tabular data in a visually appealing ASCII table format | https://github.com/jazzband/prettytable |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| pydantic | 2.7.0 | MIT License | Data validation using Python type hints | https://github.com/pydantic/pydantic |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| PyYAML | 6.0 | MIT License | YAML parser and emitter for Python | https://pyyaml.org/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| ray | 2.23.0 | Apache 2.0 | Ray provides a simple, universal API for building distributed applications. | https://github.com/ray-project/ray |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| stable-baselines3 | 2.1.0 | MIT | Pytorch version of Stable Baselines, implementations of reinforcement learning algorithms. | https://github.com/DLR-RM/stable-baselines3 |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| tensorflow | 2.12.0 | Apache Software License | TensorFlow is an open source machine learning framework for everyone. | https://www.tensorflow.org/ |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+
| typer | 0.9.0 | MIT License | Typer, build great CLIs. Easy to code. Based on Python type hints. | https://github.com/tiangolo/typer |
+-------------------+---------+------------------------------------+-------------------------------------------------------------------------------------------------------+----------------------------------------------+

View File

@@ -0,0 +1,49 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
Defining variations in the config files
================
PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.
When using a fixed scenario, a single yaml config file is used. However, to use episode schedules, PrimAITE uses a directory with several config files that work together.
Defining variations in the config file.
Base scenario
*************
The base scenario is essentially the same as a fixed YAML configuration, but it can contain placeholders that are populated with episode-specific data at runtime. The base scenario contains any network, agent, or settings that remain fixed for the entire training/evaluation session.
The placeholders are defined as YAML Aliases and they are denoted by an asterisk (*placeholder).
Variations
**********
For each variation that could be used in a placeholder, there is a separate yaml file that contains the data that should populate the placeholder.
The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand ``&anchor``.
Learn more about YAML Aliases and Anchors here.
Schedule
********
Users must define which combination of scenario variations should be loaded in each episode. This takes the form of a YAML file with a relative path to the base scenario and a list of paths to be loaded in during each episode.
It takes the following format:
.. code-block:: yaml
base_scenario: base.yaml
schedule:
0: # list of variations to load in at episode 0 (before the first call to env.reset() happens)
- laydown_1.yaml
- attack_1.yaml
1: # list of variations to load in at episode 1 (after the first env.reset() call)
- laydown_2.yaml
- attack_2.yaml
For more information please refer to the ``Using Episode Schedules`` notebook in either :ref:`Executed Notebooks` or run the notebook interactively in ``notebooks/example_notebooks/``.
For further information around notebooks in general refer to the :ref:`Example Jupyter Notebooks`.

View File

@@ -36,11 +36,8 @@ dependencies = [
"polars==0.20.30",
"prettytable==3.8.0",
"PyYAML==6.0",
"stable-baselines3[extra]==2.1.0",
"tensorflow==2.12.0",
"typer[all]==0.9.0",
"pydantic==2.7.0",
"ray[rllib] >= 2.9, < 3",
"ipywidgets",
"deepdiff"
]
@@ -56,6 +53,11 @@ license-files = ["LICENSE"]
[project.optional-dependencies]
rl = [
"ray[rllib] >= 2.20.0, < 3",
"tensorflow==2.12.0",
"stable-baselines3[extra]==2.1.0",
]
dev = [
"build==0.10.0",
"flake8==6.0.0",

View File

@@ -14,7 +14,7 @@ if TYPE_CHECKING:
pass
class AgentActionHistoryItem(BaseModel):
class AgentHistoryItem(BaseModel):
"""One entry of an agent's action log - what the agent did and how the simulator responded in 1 step."""
timestep: int
@@ -32,6 +32,8 @@ class AgentActionHistoryItem(BaseModel):
response: RequestResponse
"""The response sent back by the simulator for this action."""
reward: Optional[float] = None
class AgentStartSettings(BaseModel):
"""Configuration values for when an agent starts performing actions."""
@@ -110,7 +112,7 @@ class AbstractAgent(ABC):
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
self.action_history: List[AgentActionHistoryItem] = []
self.history: List[AgentHistoryItem] = []
def update_observation(self, state: Dict) -> ObsType:
"""
@@ -130,7 +132,7 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.update(state=state, last_action_response=self.action_history[-1])
return self.reward_function.update(state=state, last_action_response=self.history[-1])
@abstractmethod
def get_action(self, obs: ObsType, timestep: int = 0) -> Tuple[str, Dict]:
@@ -161,12 +163,16 @@ class AbstractAgent(ABC):
self, timestep: int, action: str, parameters: Dict[str, Any], request: RequestFormat, response: RequestResponse
) -> None:
"""Process the response from the most recent action."""
self.action_history.append(
AgentActionHistoryItem(
self.history.append(
AgentHistoryItem(
timestep=timestep, action=action, parameters=parameters, request=request, response=response
)
)
def save_reward_to_history(self) -> None:
"""Update the most recent history item with the reward value."""
self.history[-1].reward = self.reward_function.current_reward
class AbstractScriptedAgent(AbstractAgent):
"""Base class for actors which generate their own behaviour."""

View File

@@ -34,7 +34,7 @@ from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
if TYPE_CHECKING:
from primaite.game.agent.interface import AgentActionHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
_LOGGER = getLogger(__name__)
WhereType = Optional[Iterable[Union[str, int]]]
@@ -44,7 +44,7 @@ class AbstractReward:
"""Base class for reward function components."""
@abstractmethod
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -64,7 +64,7 @@ class AbstractReward:
class DummyReward(AbstractReward):
"""Dummy reward function component which always returns 0."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state."""
return 0.0
@@ -104,7 +104,7 @@ class DatabaseFileIntegrity(AbstractReward):
file_name,
]
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -159,7 +159,7 @@ class WebServer404Penalty(AbstractReward):
"""
self.location_in_state = ["network", "nodes", node_hostname, "services", service_name]
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the reward for the current state.
:param state: The current state of the simulation.
@@ -213,7 +213,7 @@ class WebpageUnavailablePenalty(AbstractReward):
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "WebBrowser"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
@@ -273,7 +273,7 @@ class GreenAdminDatabaseUnreachablePenalty(AbstractReward):
self.location_in_state: List[str] = ["network", "nodes", node_hostname, "applications", "DatabaseClient"]
self._last_request_failed: bool = False
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""
Calculate the reward based on current simulation state, and the recent agent action.
@@ -343,7 +343,7 @@ class SharedReward(AbstractReward):
self.callback: Callable[[str], float] = default_callback
"""Method that retrieves an agent's current reward given the agent's name."""
def calculate(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def calculate(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Simply access the other agent's reward and return it."""
return self.callback(self.agent_name)
@@ -389,7 +389,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def update(self, state: Dict, last_action_response: "AgentActionHistoryItem") -> float:
def update(self, state: Dict, last_action_response: "AgentHistoryItem") -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.

View File

@@ -160,6 +160,7 @@ class PrimaiteGame:
agent = self.agents[agent_name]
if self.step_counter > 0: # can't get reward before first action
agent.update_reward(state=state)
agent.save_reward_to_history()
agent.update_observation(state=state) # order of this doesn't matter so just use reward order
agent.reward_function.total_reward += agent.reward_function.current_reward

View File

@@ -22,7 +22,7 @@
"# Imports\n",
"\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"from primaite.game.agent.interface import AgentHistoryItem\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"import yaml\n",
"from pprint import pprint"
@@ -63,7 +63,7 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",

View File

@@ -59,7 +59,7 @@
"\n",
"At the start of every episode, the red agent randomly chooses either client 1 or client 2 to login to. It waits a bit then sends a DELETE query to the database from its chosen client. If the delete is successful, the database file is flagged as compromised to signal that data is not available.\n",
"\n",
"[<img src=\"_package_data/uc2_attack.png\" width=\"500\"/>](_package_data/uc2_attack.png)\n",
"![uc2_attack](./_package_data/uc2_attack.png)\n",
"\n",
"_(click image to enlarge)_"
]
@@ -180,15 +180,15 @@
"| link_id | endpoint_a | endpoint_b |\n",
"|---------|------------------|-------------------|\n",
"| 1 | router_1 | switch_1 |\n",
"| 1 | router_1 | switch_2 |\n",
"| 1 | switch_1 | domain_controller |\n",
"| 1 | switch_1 | web_server |\n",
"| 1 | switch_1 | database_server |\n",
"| 1 | switch_1 | backup_server |\n",
"| 1 | switch_1 | security_suite |\n",
"| 1 | switch_2 | client_1 |\n",
"| 1 | switch_2 | client_2 |\n",
"| 1 | switch_2 | security_suite |\n",
"| 2 | router_1 | switch_2 |\n",
"| 3 | switch_1 | domain_controller |\n",
"| 4 | switch_1 | web_server |\n",
"| 5 | switch_1 | database_server |\n",
"| 6 | switch_1 | backup_server |\n",
"| 7 | switch_1 | security_suite |\n",
"| 8 | switch_2 | client_1 |\n",
"| 9 | switch_2 | client_2 |\n",
"| 10 | switch_2 | security_suite |\n",
"\n",
"\n",
"The ACL rules in the observation space appear in the same order that they do in the actual ACL. Though, only the first 10 rules are shown, there are default rules lower down that cannot be changed by the agent. The extra rules just allow the network to function normally, by allowing pings, ARP traffic, etc.\n",
@@ -392,7 +392,7 @@
"# Imports\n",
"from primaite.config.load import data_manipulation_config_path\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite.game.agent.interface import AgentActionHistoryItem\n",
"from primaite.game.agent.interface import AgentHistoryItem\n",
"import yaml\n",
"from pprint import pprint\n"
]
@@ -401,7 +401,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Instantiate the environment. We also disable the agent observation flattening.\n",
"Instantiate the environment. \n",
"We will also disable the agent observation flattening.\n",
"\n",
"This cell will print the observation when the network is healthy. You should be able to verify Node file and service statuses against the description above."
]
@@ -444,7 +445,7 @@
"source": [
"def friendly_output_red_action(info):\n",
" # parse the info dict form step output and write out what the red agent is doing\n",
" red_info : AgentActionHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_info : AgentHistoryItem = info['agent_actions']['data_manipulation_attacker']\n",
" red_action = red_info.action\n",
" if red_action == 'DONOTHING':\n",
" red_str = 'DO NOTHING'\n",
@@ -705,7 +706,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.8"
}
},
"nbformat": 4,

View File

@@ -25,13 +25,13 @@
"from primaite.game.game import PrimaiteGame\n",
"import yaml\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"\n",
"import ray\n",
"from ray import air, tune\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
"from primaite.session.environment import PrimaiteRayMARLEnv\n",
"from primaite.session.ray_envs import PrimaiteRayMARLEnv\n",
"\n",
"# If you get an error saying this config file doesn't exist, you may need to run `primaite setup` in your command line\n",
"# to copy the files to your user data path.\n",
@@ -60,8 +60,8 @@
" policies={'defender_1','defender_2'}, # These names are the same as the agents defined in the example config.\n",
" policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,\n",
" )\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)#, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .environment(env=PrimaiteRayMARLEnv, env_config=cfg)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
" )\n"
]

View File

@@ -18,8 +18,7 @@
"import yaml\n",
"from primaite.config.load import data_manipulation_config_path\n",
"\n",
"from primaite.session.environment import PrimaiteRayEnv\n",
"from ray.rllib.algorithms import ppo\n",
"from primaite.session.ray_envs import PrimaiteRayEnv\n",
"from ray import air, tune\n",
"import ray\n",
"from ray.rllib.algorithms.ppo import PPOConfig\n",
@@ -52,8 +51,8 @@
"\n",
"config = (\n",
" PPOConfig()\n",
" .environment(env=PrimaiteRayEnv, env_config=env_config, disable_env_checking=True)\n",
" .rollouts(num_rollout_workers=0)\n",
" .environment(env=PrimaiteRayEnv, env_config=env_config)\n",
" .env_runners(num_env_runners=0)\n",
" .training(train_batch_size=128)\n",
")\n"
]
@@ -74,7 +73,7 @@
"tune.Tuner(\n",
" \"PPO\",\n",
" run_config=air.RunConfig(\n",
" stop={\"timesteps_total\": 5 * 128}\n",
" stop={\"timesteps_total\": 512}\n",
" ),\n",
" param_space=config\n",
").fit()\n"
@@ -97,7 +96,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.11"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -43,7 +43,10 @@
"outputs": [],
"source": [
"with open(data_manipulation_config_path(), 'r') as f:\n",
" cfg = yaml.safe_load(f)"
" cfg = yaml.safe_load(f)\n",
"for agent in cfg['agents']:\n",
" if agent['ref'] == 'defender':\n",
" agent['agent_settings']['flatten_obs']=True"
]
},
{
@@ -177,7 +180,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.10"
"version": "3.10.12"
}
},
"nbformat": 4,

View File

@@ -13,50 +13,6 @@
"directory with several config files that work together."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining variations in the config file.\n",
"\n",
"### Base scenario\n",
"The base scenario is essentially the same as a fixed YAML configuration, but it can contain placeholders that are \n",
"populated with episode-specific data at runtime. The base scenario contains any network, agent, or settings that\n",
"remain fixed for the entire training/evaluation session.\n",
"\n",
"The placeholders are defined as YAML Aliases and they are denoted by an asterisk (`*placeholder`).\n",
"\n",
"### Variations\n",
"For each variation that could be used in a placeholder, there is a separate yaml file that contains the data that should populate the placeholder.\n",
"\n",
"The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand (`&anchor`).\n",
"\n",
"[Learn more about YAML Aliases and Anchors here.](https://www.educative.io/blog/advanced-yaml-syntax-cheatsheet#:~:text=YAML%20Anchors%20and%20Alias)\n",
"\n",
"### Schedule\n",
"Users must define which combination of scenario variations should be loaded in each episode. This takes the form of a\n",
"YAML file with a relative path to the base scenario and a list of paths to be loaded in during each episode.\n",
"\n",
"It takes the following format:\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"```yaml\n",
"base_scenario: base.yaml\n",
"schedule:\n",
" 0: # list of variations to load in at episode 0 (before the first call to env.reset() happens)\n",
" - laydown_1.yaml\n",
" - attack_1.yaml\n",
" 1: # list of variations to load in at episode 1 (after the first env.reset() call)\n",
" - laydown_2.yaml\n",
" - attack_2.yaml\n",
"```\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -298,8 +254,8 @@
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_A'].action_history[i].action\n",
" red_action = env.game.agents['red_A'].action_history[i].action\n",
" green_action = env.game.agents['green_A'].history[i].action\n",
" red_action = env.game.agents['red_A'].history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]
@@ -329,8 +285,8 @@
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_B'].action_history[i].action\n",
" red_action = env.game.agents['red_B'].action_history[i].action\n",
" green_action = env.game.agents['green_B'].history[i].action\n",
" red_action = env.game.agents['red_B'].history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]

View File

@@ -4,8 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple multi-processing demo using SubprocVecEnv from SB3\n",
"Based on a code example provided by Rachael Proctor."
"## Simple multi-processing demo using SubprocVecEnv from SB3"
]
},
{

View File

@@ -4,7 +4,6 @@ from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
import gymnasium
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite import getLogger
from primaite.game.agent.interface import ProxyAgent
@@ -12,6 +11,7 @@ from primaite.game.game import PrimaiteGame
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.system.core.packet_capture import PacketCapture
_LOGGER = getLogger(__name__)
@@ -63,7 +63,7 @@ class PrimaiteGymEnv(gymnasium.Env):
terminated = False
truncated = self.game.calculate_truncated()
info = {
"agent_actions": {name: agent.action_history[-1] for name, agent in self.game.agents.items()}
"agent_actions": {name: agent.history[-1] for name, agent in self.game.agents.items()}
} # tell us what all the agents did for convenience.
if self.game.save_step_metadata:
self._write_step_metadata_json(step, action, state, reward)
@@ -94,9 +94,10 @@ class PrimaiteGymEnv(gymnasium.Env):
self.average_reward_per_episode[self.episode_counter] = self.agent.reward_function.total_reward
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
PacketCapture.clear()
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
@@ -130,166 +131,5 @@ class PrimaiteGymEnv(gymnasium.Env):
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
def close(self):
"""Close the simulation."""
self.env.close()
@property
def game(self) -> PrimaiteGame:
"""Pass through game from env."""
return self.env.game
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict
"""
self.episode_counter: int = 0
"""Current episode number."""
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
super().__init__()
@property
def agents(self) -> Dict[str, ProxyAgent]:
"""Grab a fresh reference to the agents from this episode's game object."""
return {name: self.game.rl_agents[name] for name in self._agent_ids}
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
def step(
self, actions: Dict[str, ActType]
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
:type actions: Dict[str, ActType]
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
identifier.
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
"""
step = self.game.step_counter
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.pre_timestep()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
# 3. Get next observations
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(step, actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": step,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
"state": state,
}
with open(path, "w") as file:
json.dump(data, file)
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)

View File

@@ -87,7 +87,7 @@ class PrimaiteIO:
"""Return the path where agent actions will be saved."""
return self.session_path / "agent_actions" / f"episode_{episode}.json"
def write_agent_actions(self, agent_actions: Dict[str, List], episode: int) -> None:
def write_agent_log(self, agent_actions: Dict[str, List], episode: int) -> None:
"""Take the contents of the agent action log and write it to a file.
:param episode: Episode number

View File

@@ -0,0 +1,177 @@
import json
from typing import Dict, SupportsFloat, Tuple
import gymnasium
from gymnasium.core import ActType, ObsType
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.environment import _LOGGER, PrimaiteGymEnv
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.system.core.packet_capture import PacketCapture
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration. It must contain a single key, `game`
which is the PrimaiteGame instance.
:type env_config: Dict
"""
self.episode_counter: int = 0
"""Current episode number."""
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.terminateds = set()
self.truncateds = set()
self.observation_space = gymnasium.spaces.Dict(
{
name: gymnasium.spaces.flatten_space(agent.observation_manager.space)
for name, agent in self.agents.items()
}
)
self.action_space = gymnasium.spaces.Dict(
{name: agent.action_manager.space for name, agent in self.agents.items()}
)
self._obs_space_in_preferred_format = True
self._action_space_in_preferred_format = True
super().__init__()
@property
def agents(self) -> Dict[str, ProxyAgent]:
"""Grab a fresh reference to the agents from this episode's game object."""
return {name: self.game.rl_agents[name] for name in self._agent_ids}
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
rewards = {name: agent.reward_function.total_reward for name, agent in self.agents.items()}
_LOGGER.info(f"Resetting environment, episode {self.episode_counter}, " f"avg. reward: {rewards}")
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
self.episode_counter += 1
PacketCapture.clear()
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
def step(
self, actions: Dict[str, ActType]
) -> Tuple[Dict[str, ObsType], Dict[str, SupportsFloat], Dict[str, bool], Dict[str, bool], Dict]:
"""Perform a step in the environment. Adherent to Ray MultiAgentEnv step API.
:param actions: Dict of actions. The key is agent identifier and the value is a gymnasium action instance.
:type actions: Dict[str, ActType]
:return: Observations, rewards, terminateds, truncateds, and info. Each one is a dictionary keyed by agent
identifier.
:rtype: Tuple[Dict[str,ObsType], Dict[str, SupportsFloat], Dict[str,bool], Dict[str,bool], Dict]
"""
step = self.game.step_counter
# 1. Perform actions
for agent_name, action in actions.items():
self.agents[agent_name].store_action(action)
self.game.pre_timestep()
self.game.apply_agent_actions()
# 2. Advance timestep
self.game.advance_timestep()
# 3. Get next observations
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()
# 4. Get rewards
rewards = {name: agent.reward_function.current_reward for name, agent in self.agents.items()}
_LOGGER.info(f"step: {self.game.step_counter}, Rewards: {rewards}")
terminateds = {name: False for name, _ in self.agents.items()}
truncateds = {name: self.game.calculate_truncated() for name, _ in self.agents.items()}
infos = {name: {} for name, _ in self.agents.items()}
terminateds["__all__"] = len(self.terminateds) == len(self.agents)
truncateds["__all__"] = self.game.calculate_truncated()
if self.game.save_step_metadata:
self._write_step_metadata_json(step, actions, state, rewards)
return next_obs, rewards, terminateds, truncateds, infos
def _write_step_metadata_json(self, step: int, actions: Dict, state: Dict, rewards: Dict):
output_dir = SIM_OUTPUT.path / f"episode_{self.episode_counter}" / "step_metadata"
output_dir.mkdir(parents=True, exist_ok=True)
path = output_dir / f"step_{step}.json"
data = {
"episode": self.episode_counter,
"step": step,
"actions": {agent_name: int(action) for agent_name, action in actions.items()},
"reward": rewards,
"state": state,
}
with open(path, "w") as file:
json.dump(data, file)
def _get_obs(self) -> Dict[str, ObsType]:
"""Return the current observation."""
obs = {}
for agent_name in self._agent_ids:
agent = self.game.rl_agents[agent_name]
unflat_space = agent.observation_manager.space
unflat_obs = agent.observation_manager.current_observation
obs[agent_name] = gymnasium.spaces.flatten(unflat_space, unflat_obs)
return obs
def close(self):
"""Close the simulation."""
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.history for name, agent in self.game.agents.items()}
self.io.write_agent_log(agent_actions=all_agent_actions, episode=self.episode_counter)
class PrimaiteRayEnv(gymnasium.Env):
"""Ray wrapper that accepts a single `env_config` parameter in init function for compatibility with Ray."""
def __init__(self, env_config: Dict) -> None:
"""Initialise the environment.
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
def reset(self, *, seed: int = None, options: dict = None) -> Tuple[ObsType, Dict]:
"""Reset the environment."""
return self.env.reset(seed=seed)
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict]:
"""Perform a step in the environment."""
return self.env.step(action)
def close(self):
"""Close the simulation."""
self.env.close()
@property
def game(self) -> PrimaiteGame:
"""Pass through game from env."""
return self.env.game

View File

@@ -21,6 +21,8 @@ class PacketCapture:
The PCAPs are logged to: <simulation output directory>/<hostname>/<hostname>_<ip address>_pcap.log
"""
_logger_instances: List[logging.Logger] = []
def __init__(
self,
hostname: str,
@@ -65,10 +67,12 @@ class PacketCapture:
if outbound:
self.outbound_logger = logging.getLogger(self._get_logger_name(outbound))
PacketCapture._logger_instances.append(self.outbound_logger)
logger = self.outbound_logger
else:
self.inbound_logger = logging.getLogger(self._get_logger_name(outbound))
logger = self.inbound_logger
PacketCapture._logger_instances.append(self.inbound_logger)
logger.setLevel(60) # Custom log level > CRITICAL to prevent any unwanted standard DEBUG-CRITICAL logs
logger.addHandler(file_handler)
@@ -122,3 +126,13 @@ class PacketCapture:
if SIM_OUTPUT.save_pcap_logs:
msg = frame.model_dump_json()
self.outbound_logger.log(level=60, msg=msg) # Log at custom log level > CRITICAL
@staticmethod
def clear():
"""Close all open PCAP file handlers."""
for logger in PacketCapture._logger_instances:
handlers = logger.handlers[:]
for handler in handlers:
logger.removeHandler(handler)
handler.close()
PacketCapture._logger_instances = []

View File

@@ -3,7 +3,7 @@ import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.session.environment import PrimaiteRayMARLEnv
from primaite.session.ray_envs import PrimaiteRayMARLEnv
from tests import TEST_ASSETS_ROOT
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"

View File

@@ -8,7 +8,7 @@ from ray.rllib.algorithms import ppo
from primaite.config.load import data_manipulation_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayEnv
from primaite.session.ray_envs import PrimaiteRayEnv
@pytest.mark.skip(reason="Slow, reenable later")

View File

@@ -4,7 +4,8 @@ import yaml
from gymnasium.core import ObsType
from numpy import ndarray
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayMARLEnv
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests import TEST_ASSETS_ROOT

View File

@@ -1,7 +1,8 @@
import pytest
import yaml
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.environment import PrimaiteGymEnv
from primaite.session.ray_envs import PrimaiteRayEnv, PrimaiteRayMARLEnv
from tests.conftest import TEST_ASSETS_ROOT
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"

View File

@@ -1,6 +1,6 @@
import yaml
from primaite.game.agent.interface import AgentActionHistoryItem
from primaite.game.agent.interface import AgentHistoryItem
from primaite.game.agent.rewards import GreenAdminDatabaseUnreachablePenalty, WebpageUnavailablePenalty
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteGymEnv
@@ -75,7 +75,7 @@ def test_uc2_rewards(game_and_agent):
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)
@@ -91,7 +91,7 @@ def test_uc2_rewards(game_and_agent):
state = game.get_sim_state()
reward_value = comp.calculate(
state,
last_action_response=AgentActionHistoryItem(
last_action_response=AgentHistoryItem(
timestep=0, action="NODE_APPLICATION_EXECUTE", parameters={}, request=["execute"], response=response
),
)