#2374 Remove primaite session

This commit is contained in:
Marek Wolan
2024-04-16 11:26:17 +01:00
parent 72dd84886b
commit 8d0d323e0b
33 changed files with 122 additions and 1700 deletions

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
import platform

View File

@@ -5,8 +5,7 @@
PrimAITE |VERSION| Configuration
********************************
PrimAITE uses a single configuration file to define everything needed to train and evaluate an RL policy in a custom cybersecurity scenario. This includes the configuration of the network, the scripted or trained agents that interact with the network, as well as settings that define how to perform training in Stable Baselines 3 or Ray RLLib.
The entire config is used by the ``PrimaiteSession`` object for users who wish to let PrimAITE handle the agent definition and training. If you wish to define custom agents and control the training loop yourself, you can use the config with the ``PrimaiteGame``, and ``PrimaiteGymEnv`` objects instead. That way, only the network configuration and agent setup parts of the config are used, and the training section is ignored.
PrimAITE uses a single configuration file to define everything needed to create the training environment for RL agents, including the network, the scripted agents, and the RL agent's action space, observation space, and reward function.
Example Configuration Hierarchy
###############################
@@ -14,8 +13,6 @@ The top level configuration items in a configuration file is as follows
.. code-block:: yaml
training_config:
...
io_settings:
...
game:
@@ -33,7 +30,6 @@ Configurable items
.. toctree::
:maxdepth: 1
configuration/training_config.rst
configuration/io_settings.rst
configuration/game.rst
configuration/agents.rst

View File

@@ -13,42 +13,12 @@ This section configures how PrimAITE saves data during simulation and training.
.. code-block:: yaml
io_settings:
save_final_model: True
save_checkpoints: False
checkpoint_interval: 10
# save_logs: True
# save_transactions: False
save_agent_actions: True
save_step_metadata: False
save_pcap_logs: False
save_sys_logs: False
``save_final_model``
--------------------
Optional. Default value is ``True``.
Only used if training with PrimaiteSession.
If ``True``, the policy will be saved after the final training iteration.
``save_checkpoints``
--------------------
Optional. Default value is ``False``.
Only used if training with PrimaiteSession.
If ``True``, the policy will be saved periodically during training.
``checkpoint_interval``
-----------------------
Optional. Default value is ``10``.
Only used if training with PrimaiteSession and if ``save_checkpoints`` is ``True``.
Defines how often to save the policy during training.
``save_logs``
-------------

View File

@@ -1,75 +0,0 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
``training_config``
===================
Configuration items relevant to how the Reinforcement Learning agent(s) will be trained.
``training_config`` hierarchy
-----------------------------
.. code-block:: yaml
training_config:
rl_framework: SB3 # or RLLIB_single_agent or RLLIB_multi_agent
rl_algorithm: PPO # or A2C
n_learn_episodes: 5
max_steps_per_episode: 200
n_eval_episodes: 1
deterministic_eval: True
seed: 123
``rl_framework``
----------------
The RL (Reinforcement Learning) Framework to use in the training session
Options available are:
- ``SB3`` (Stable Baselines 3)
- ``RLLIB_single_agent`` (Single Agent Ray RLLib)
- ``RLLIB_multi_agent`` (Multi Agent Ray RLLib)
``rl_algorithm``
----------------
The Reinforcement Learning Algorithm to use in the training session
Options available are:
- ``PPO`` (Proximal Policy Optimisation)
- ``A2C`` (Advantage Actor Critic)
``n_learn_episodes``
--------------------
The number of episodes to train the agent(s).
This should be an integer value above ``0``
``max_steps_per_episode``
-------------------------
The number of steps each episode will last for.
This should be an integer value above ``0``.
``n_eval_episodes``
-------------------
Optional. Default value is ``0``.
The number of evaluation episodes to run the trained agent for.
This should be an integer value above ``0``.
``deterministic_eval``
----------------------
Optional. By default this value is ``False``.
If this is set to ``True``, the agents will act deterministically instead of stochastically.
``seed``
--------
Optional.
The seed is used (alongside ``deterministic_eval``) to reproduce a previous instance of training and evaluation of an RL agent.
The seed should be an integer value.
Useful for debugging.

View File

@@ -10,15 +10,6 @@ The simulator and game layer communicate using the PrimAITE State API and the Pr
The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components:
PrimAITE Session
================
.. admonition:: Deprecated
:class: deprecated
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
Agents
======

View File

@@ -1,41 +0,0 @@
.. only:: comment
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
.. _run a primaite session:
.. admonition:: Deprecated
:class: deprecated
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
Run a PrimAITE Session
======================
``PrimaiteSession`` allows the user to train or evaluate an RL agent on the primaite simulation with just a config file,
no code required. It manages the lifecycle of a training or evaluation session, including the setup of the environment,
policy, simulator, agents, and IO.
If you want finer control over the RL policy, you can interface with the :py:module::`primaite.session.environment`
module directly without running a session.
Run
---
A PrimAITE session can be started either with the ``primaite session`` command from the cli
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
There are two parameters that can be specified:
- ``--config``: The path to the config file to use. If not specified, the default config file is used.
- ``--agent-load-file``: The path to the pre-trained agent to load. If not specified, a new agent is created.
Outputs
-------
Running a session creates a session output directory in your user data folder. The filepath looks like this:
``~/primaite/{VERSION}/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that
contains the action, reward, and simulation state. These can be found in
``~/primaite/{VERSION}/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<n>/step_metadata/step_<n>.json``

View File

@@ -114,23 +114,3 @@ def setup(overwrite_existing: bool = True) -> None:
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
config: Optional[str] = None,
agent_load_file: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from primaite.config.load import data_manipulation_config_path
from primaite.main import run
if not config:
config = data_manipulation_config_path()
print(config)
run(config_path=config, agent_load_path=agent_load_file)

View File

@@ -1,17 +1,3 @@
training_config:
rl_framework: RLLIB_multi_agent
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 2
agent_references:
- defender_1
- defender_2
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -210,8 +210,8 @@ class PrimaiteGame:
"""Create a PrimaiteGame object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteGame.
1. io_settings: options for logging data during training
2. game_config: options for the game itself, such as agents.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.

View File

@@ -1,47 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import data_manipulation_config_path, load
from primaite.session.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
_LOGGER = getLogger(__name__)
def run(
config_path: Optional[Union[str, Path]] = "",
agent_load_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config")
args = parser.parse_args()
if not args.config:
args.config = data_manipulation_config_path()
run(args.config)

View File

@@ -48,7 +48,7 @@ class PrimaiteGymEnv(gymnasium.Env):
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
# make ProxyAgent store the action chosen by the RL policy
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored

View File

@@ -95,6 +95,7 @@ class PrimaiteIO:
@classmethod
def from_config(cls, config: Dict) -> "PrimaiteIO":
"""Create an instance of PrimaiteIO based on a configuration dict."""
config = config or {}
new = cls(settings=cls.Settings(**config))
return new

View File

@@ -1,4 +0,0 @@
from primaite.session.policy.rllib import RaySingleAgentPolicy
from primaite.session.policy.sb3 import SB3Policy
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]

View File

@@ -1,82 +0,0 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
"""Base class for reinforcement learning agents."""
_registry: Dict[str, Type["PolicyABC"]] = {}
"""
Registry of policy types, keyed by name.
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register a policy subclass.
:param name: Identifier used by from_config to create an instance of the policy.
:type name: str
:raises ValueError: When attempting to create a policy with a duplicate name.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate policy name {identifier}")
cls._registry[identifier] = cls
return
@abstractmethod
def __init__(self, session: "PrimaiteSession") -> None:
"""
Initialize a reinforcement learning policy.
:param session: The session context.
:type session: PrimaiteSession
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteSession" = session
"""Reference to the session."""
@abstractmethod
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
pass
@abstractmethod
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
"""Evaluate the agent."""
pass
@abstractmethod
def save(self, save_path: Path) -> None:
"""Save the agent."""
pass
@abstractmethod
def load(self) -> None:
"""Load agent from a file."""
pass
def close(self) -> None:
"""Close the agent."""
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
"""
Create an RL policy from a config by calling the relevant subclass's from_config method.
Subclasses should not call super().from_config(), they should just handle creation form config.
"""
# Assume that basically the contents of training_config are passed into here.
# I should really define a config schema class using pydantic.
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)

View File

@@ -1,111 +0,0 @@
from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
from primaite import getLogger
_LOGGER = getLogger(__name__)
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
self.config = {
"env": PrimaiteRayEnv,
"env_config": {"game": session.game},
"disable_env_checking": True,
"num_rollout_workers": 0,
}
ray.shutdown()
ray.init()
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
self.config["training_iterations"] = n_episodes * timesteps_per_episode
self.config["train_batch_size"] = 128
self._algo = ppo.PPO(config=self.config)
_LOGGER.info("Starting RLLIB training session")
self._algo.train()
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
for ep in range(n_episodes):
obs, info = self.session.env.reset()
for step in range(self.session.game.options.max_episode_length):
action = self._algo.compute_single_action(observation=obs, explore=False)
obs, rew, term, trunc, info = self.session.env.step(action)
def save(self, save_path: Path) -> None:
"""Save the policy to a file."""
self._algo.save(save_path)
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
raise NotImplementedError
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
"""Create a policy from a config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
"""Mutli agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
"""Initialise multi agent policy wrapper."""
super().__init__(session=session)
self.config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in session.game.rl_agents},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": n_episodes * timesteps_per_episode},
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
),
param_space=self.config,
).fit()
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
return NotImplemented
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate trained policy."""
return NotImplemented
def save(self, save_path: Path) -> None:
"""Save policy parameters to a file."""
return NotImplemented
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
"""Create policy from config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,79 +0,0 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
self._agent_class: Type[Union[PPO, A2C]]
if algorithm == "PPO":
self._agent_class = PPO
policy = PPO_MLP
elif algorithm == "A2C":
self._agent_class = A2C
policy = A2C_MLP
else:
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
self._agent = self._agent_class(
policy=policy,
env=self.session.env,
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
seed=seed,
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
if self.session.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)
else:
checkpoint_callback = None
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
_ = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
def save(self, save_path: Path) -> None:
"""
Save the current policy parameters.
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
Therefore, this method is only used to save the final model.
"""
self._agent.save(save_path)
def load(self, model_path: Path) -> None:
"""Load agent from a checkpoint."""
self._agent = self._agent_class.load(model_path, env=self.session.env)
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
"""Create an agent from config file."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,119 +0,0 @@
# raise DeprecationWarning("This module is deprecated")
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import PrimaiteIO
# from primaite.game.game import PrimaiteGame
from primaite.session.policy.policy import PolicyABC
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
def __init__(self, game_cfg: Dict):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
"""The environment that the RL algorithm can consume."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.io_manager: Optional["PrimaiteIO"] = None
"""IO manager for the session."""
self.game_cfg: Dict = game_cfg
"""Primaite Game object for managing main simulation loop and agents."""
self.save_checkpoints: bool = False
"""Whether to save checkpoints."""
self.checkpoint_interval: int = 10
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
def start_session(self) -> None:
"""Commence the training/eval session."""
print("Starting Primaite Session")
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
sess = cls(game_cfg=cfg)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config=cfg)
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
elif sess.training_options.rl_framework == "SB3":
sess.env = PrimaiteGymEnv(game_config=cfg)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess

View File

@@ -271,7 +271,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
# Update the state with information from Layer3Interface
state.update(Layer3Interface.describe_state(self))
state["frequency"] = self.frequency
state["frequency"] = self.frequency.value
return state

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
from pathlib import Path

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Tuple, Union

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import csv
from logging import Logger

View File

@@ -1,12 +1,3 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
se3ed: 333 # Purposeful typo to check that error is raised with bad configuration.
n_learn_steps: 2560
n_eval_episodes: 5
game:
ports:
- ARP

View File

@@ -1,16 +1,3 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 0
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
game:
ports:
- ARP

View File

@@ -30,18 +30,6 @@
# | external_computer |------| switch_3 |------| external_server |
# ----------------------- -------------- ---------------------
#
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_step_metadata: false
save_pcap_logs: true

View File

@@ -1,21 +1,3 @@
training_config:
rl_framework: RLLIB_multi_agent
rl_algorithm: PPO
seed: 333
n_learn_episodes: 2
n_eval_episodes: 1
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references: #not used :(
- defender1
- defender2
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game:
max_episode_length: 128
ports:
@@ -31,11 +13,12 @@ game:
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
type: PeriodicAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
@@ -901,86 +884,6 @@ agents:
options:
target_router_nodename: router_1
position: 9
38:
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
39:
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
40:
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
41:
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
42:
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
43:
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
44:
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
45:
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
46:
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
47:
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
48:
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
49:
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
50:
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
51:
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
52:
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
53:
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
options:

View File

@@ -1,15 +1,3 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_agent_actions: false
save_step_metadata: false

View File

@@ -1,15 +1,3 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_agent_actions: true
save_step_metadata: false

View File

@@ -1,15 +1,3 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 10
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_agent_actions: true
save_step_metadata: true
@@ -568,7 +556,7 @@ agents:
agent_settings:
# ...
flatten_obs: true

View File

@@ -1,737 +0,0 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 10
n_eval_episodes: 0
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
game:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_2_green_user
team: GREEN
type: ProbabilisticAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
action_map:
0:
action: DONOTHING
options: {}
options:
nodes:
- node_name: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: data_manipulation_attacker
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
- type: NODE_OS_SCAN
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
options:
nodes:
- node_name: client_1
applications:
- application_name: DataManipulationBot
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_settings:
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
hosts:
- hostname: domain_controller
- hostname: web_server
services:
- service_name: WebServer
- hostname: database_server
folders:
- folder_name: database
files:
- file_name: database.db
- hostname: backup_server
- hostname: security_suite
- hostname: client_1
- hostname: client_2
num_services: 1
num_applications: 0
num_folders: 1
num_files: 1
num_nics: 2
include_num_access: false
include_nmne: true
routers:
- hostname: router_1
num_ports: 0
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
wildcard_list:
- 0.0.0.1
port_list:
- 80
- 5432
protocol_list:
- ICMP
- TCP
- UDP
num_rules: 10
- type: LINKS
label: LINKS
options:
link_references:
- router_1:eth-1<->switch_1:eth-8
- router_1:eth-2<->switch_2:eth-8
- switch_1:eth-1<->domain_controller:eth-1
- switch_1:eth-2<->web_server:eth-1
- switch_1:eth-3<->database_server:eth-1
- switch_1:eth-4<->backup_server:eth-1
- switch_1:eth-7<->security_suite:eth-1
- switch_2:eth-1<->client_1:eth-1
- switch_2:eth-2<->client_2:eth-1
- switch_2:eth-7<->security_suite:eth-2
- type: "NONE"
label: ICS
options: {}
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_SERVICE_FIX
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: ROUTER_ACL_ADDRULE
- type: ROUTER_ACL_REMOVERULE
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 1
service_id: 0
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 1
service_id: 0
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 1
service_id: 0
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 1
service_id: 0
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 1
service_id: 0
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 1
service_id: 0
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 1
service_id: 0
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 1
service_id: 0
9: # check database.db file
action: "NODE_FILE_SCAN"
options:
node_id: 2
folder_id: 1
file_id: 0
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 2
folder_id: 1
file_id: 0
11:
action: "NODE_FILE_DELETE"
options:
node_id: 2
folder_id: 1
file_id: 0
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 2
folder_id: 1
file_id: 0
13:
action: "NODE_SERVICE_FIX"
options:
node_id: 2
service_id: 0
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 2
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 2
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 2
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 2
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 2
19: # shutdown client 1
action: "NODE_SHUTDOWN"
options:
node_id: 5
20:
action: "NODE_STARTUP"
options:
node_id: 5
21:
action: "NODE_RESET"
options:
node_id: 5
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 1
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 2
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 1 # ALL
source_port_id: 1
dest_port_id: 1
protocol_id: 1
source_wildcard_id: 0
dest_wildcard_id: 0
24: # block tcp traffic from client 1 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 3
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
25: # block tcp traffic from client 2 to web app
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 4
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 3 # web server
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
26:
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 5
permission: 2
source_ip_id: 7 # client 1
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
27:
action: "ROUTER_ACL_ADDRULE"
options:
target_router_nodename: router_1
position: 6
permission: 2
source_ip_id: 8 # client 2
dest_ip_id: 4 # database
source_port_id: 1
dest_port_id: 1
protocol_id: 3
source_wildcard_id: 0
dest_wildcard_id: 0
28:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 0
29:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 1
30:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 2
31:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 3
32:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 4
33:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 5
34:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 6
35:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 7
36:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 8
37:
action: "ROUTER_ACL_REMOVERULE"
options:
target_router_nodename: router_1
position: 9
38:
action: "HOST_NIC_DISABLE"
options:
node_id: 0
nic_id: 0
39:
action: "HOST_NIC_ENABLE"
options:
node_id: 0
nic_id: 0
40:
action: "HOST_NIC_DISABLE"
options:
node_id: 1
nic_id: 0
41:
action: "HOST_NIC_ENABLE"
options:
node_id: 1
nic_id: 0
42:
action: "HOST_NIC_DISABLE"
options:
node_id: 2
nic_id: 0
43:
action: "HOST_NIC_ENABLE"
options:
node_id: 2
nic_id: 0
44:
action: "HOST_NIC_DISABLE"
options:
node_id: 3
nic_id: 0
45:
action: "HOST_NIC_ENABLE"
options:
node_id: 3
nic_id: 0
46:
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 0
47:
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 0
48:
action: "HOST_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
49:
action: "HOST_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
50:
action: "HOST_NIC_DISABLE"
options:
node_id: 5
nic_id: 0
51:
action: "HOST_NIC_ENABLE"
options:
node_id: 5
nic_id: 0
52:
action: "HOST_NIC_DISABLE"
options:
node_id: 6
nic_id: 0
53:
action: "HOST_NIC_ENABLE"
options:
node_id: 6
nic_id: 0
options:
nodes:
- node_name: domain_controller
- node_name: web_server
- node_name: database_server
- node_name: backup_server
- node_name: security_suite
- node_name: client_1
- node_name: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
ip_list:
- 192.168.1.10
- 192.168.1.12
- 192.168.1.14
- 192.168.1.16
- 192.168.1.110
- 192.168.10.21
- 192.168.10.22
- 192.168.10.110
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_hostname: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_hostname: web_server
service_name: web_service
agent_settings:
# ...
simulation:
network:
nodes:
- type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- type: switch
hostname: switch_1
num_ports: 8
- type: switch
hostname: switch_2
num_ports: 8
- type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: WebServer
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: DatabaseService
- type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- type: FTPServer
- type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
network_interfaces:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: DataManipulationBot
options:
port_scan_p_of_success: 0.1
data_manipulation_p_of_success: 0.1
payload: "DELETE"
server_ip: 192.168.1.14
services:
- type: DNSClient
- type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- type: WebBrowser
services:
- type: DNSClient
links:
- endpoint_a_hostname: router_1
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 8
- endpoint_a_hostname: router_1
endpoint_a_port: 2
endpoint_b_hostname: switch_2
endpoint_b_port: 8
- endpoint_a_hostname: switch_1
endpoint_a_port: 1
endpoint_b_hostname: domain_controller
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 2
endpoint_b_hostname: web_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 3
endpoint_b_hostname: database_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 4
endpoint_b_hostname: backup_server
endpoint_b_port: 1
- endpoint_a_hostname: switch_1
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 1
endpoint_b_hostname: client_1
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 2
endpoint_b_hostname: client_2
endpoint_b_port: 1
- endpoint_a_hostname: switch_2
endpoint_a_port: 7
endpoint_b_hostname: security_suite
endpoint_b_port: 2

View File

@@ -13,7 +13,6 @@ from primaite.game.agent.interface import AbstractAgent
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.game import PrimaiteGame
from primaite.session.session import PrimaiteSession
from primaite.simulator import SIM_OUTPUT
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.container import Network
@@ -121,38 +120,6 @@ def file_system() -> FileSystem:
return computer.file_system
# PrimAITE v2 stuff
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.
Uses context manager for deletion of files upon exit.
"""
@classmethod
def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession":
"""Create a temporary PrimaiteSession object from a config file."""
config_path = Path(config_path)
with open(config_path, "r") as f:
config = yaml.safe_load(f)
return super().from_config(cfg=config)
def __enter__(self):
return self
def __exit__(self, type, value, tb):
pass
@pytest.fixture
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
"""Create a temporary PrimaiteSession object."""
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
config_path = request.param[0]
return TempPrimaiteSession.from_config(config_path=config_path)
@pytest.fixture(scope="function")
def client_server() -> Tuple[Computer, Server]:
network = Network()

View File

@@ -1,33 +1,28 @@
import pytest
import ray
import yaml
from ray import air, tune
from ray.rllib.algorithms.ppo import PPOConfig
from primaite.config.load import data_manipulation_config_path
from primaite.game.game import PrimaiteGame
from primaite.session.environment import PrimaiteRayMARLEnv
from tests import TEST_ASSETS_ROOT
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
@pytest.mark.skip(reason="Slow, reenable later")
def test_rllib_multi_agent_compatibility():
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
with open(data_manipulation_config_path(), "r") as f:
with open(MULTI_AGENT_PATH, "r") as f:
cfg = yaml.safe_load(f)
game = PrimaiteGame.from_config(cfg)
ray.shutdown()
ray.init()
env_config = {"game": game}
config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
.environment(env=PrimaiteRayMARLEnv, env_config=cfg)
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in game.rl_agents},
policies={agent["ref"] for agent in cfg["agents"]},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)

View File

@@ -0,0 +1,91 @@
import pydantic
import pytest
import yaml
from gymnasium.core import ObsType
from numpy import ndarray
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests import TEST_ASSETS_ROOT
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
class TestPrimaiteEnvironment:
def test_creating_env(self):
"""Check that environment loads correctly from config and it can be reset."""
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
def env_checks():
assert env is not None
assert env.game.simulation
assert len(env.game.agents) == 3
assert len(env.game.rl_agents) == 1
assert env.game.simulation.network
assert len(env.game.simulation.network.nodes) == 12
wireless = env.game.simulation.network.get_node_by_hostname("router_2")
assert isinstance(wireless, WirelessRouter)
printer = env.game.simulation.network.get_node_by_hostname("HP_LaserJet_Pro_4102fdn_printer")
assert isinstance(printer, Printer)
env_checks()
env.reset()
env_checks()
def test_step_env(self):
"""Make sure you can go all the way through the session without errors."""
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
assert (num_actions := len(env.agent.action_manager.action_map)) == 54
# run every action and make sure there's no crash
for act in range(num_actions):
env.step(act)
# try running action number outside the action map to check that it fails.
with pytest.raises(KeyError):
env.step(num_actions)
obs, rew, trunc, term, info = env.step(0)
assert isinstance(obs, ndarray)
def test_multi_agent_env(self):
"""Check that we can run a training session with a multi agent system."""
with open(MULTI_AGENT_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteRayMARLEnv(env_config=cfg)
assert set(env._agent_ids) == {"defender1", "defender2"}
assert len(env.agents) == 2
defender1 = env.agents["defender1"]
defender2 = env.agents["defender2"]
assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54
assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38
# ensure we can run all valid actions without error
for act_1 in range(num_actions_1):
env.step({"defender1": act_1, "defender2": 0})
for act_2 in range(num_actions_2):
env.step({"defender1": 0, "defender2": act_2})
# ensure we get error when taking an invalid action
with pytest.raises(KeyError):
env.step({"defender1": num_actions_1, "defender2": 0})
with pytest.raises(KeyError):
env.step({"defender1": 0, "defender2": num_actions_2})
def test_error_thrown_on_bad_configuration(self):
"""Make sure we throw an error when the config is bad."""
with open(MISCONFIGURED_PATH, "r") as f:
cfg = yaml.safe_load(f)
with pytest.raises(pydantic.ValidationError):
env = PrimaiteGymEnv(game_config=cfg)

View File

@@ -1,109 +0,0 @@
import pydantic
import pytest
from primaite.simulator.network.hardware.nodes.host.server import Printer
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
from tests import TEST_ASSETS_ROOT
from tests.conftest import TempPrimaiteSession
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
class TestPrimaiteSession:
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_creating_session(self, temp_primaite_session):
"""Check that creating a session from config works."""
with temp_primaite_session as session:
if not isinstance(session, TempPrimaiteSession):
raise AssertionError
assert session is not None
assert session.env.game.simulation
assert len(session.env.game.agents) == 3
assert len(session.env.game.rl_agents) == 1
assert session.policy
assert session.env
assert session.env.game.simulation.network
assert len(session.env.game.simulation.network.nodes) == 12
wireless = session.env.game.simulation.network.get_node_by_hostname("router_2")
assert isinstance(wireless, WirelessRouter)
printer = session.env.game.simulation.network.get_node_by_hostname("HP_LaserJet_Pro_4102fdn_printer")
assert isinstance(printer, Printer)
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):
"""Make sure you can go all the way through the session without errors."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
session_path = session.io_manager.session_path
assert session_path.exists()
print(list(session_path.glob("*")))
checkpoint_dir = session_path / "checkpoints" / "sb3_final"
assert checkpoint_dir.exists()
checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
assert checkpoint_1.exists()
assert checkpoint_2.exists()
assert not checkpoint_3.exists()
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
def test_training_only_session(self, temp_primaite_session):
"""Check that you can run a training-only session."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
def test_eval_only_session(self, temp_primaite_session):
"""Check that you can load a model and run an eval-only session."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
# TODO: include checks that the model was loaded and that the eval-only session ran
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.skip(reason="Slow, reenable later")
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
def test_multi_agent_session(self, temp_primaite_session):
"""Check that we can run a training session with a multi agent system."""
with temp_primaite_session as session:
session.start_session()
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
@pytest.mark.skip(
reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, "
"reset doesn't implement software restore."
)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:
session: TempPrimaiteSession
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.uninstall("DataManipulationBot")
assert "DataManipulationBot" not in client_1.software_manager.software
session.game.reset()
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
assert "DataManipulationBot" in client_1.software_manager.software