#2374 Remove primaite session
This commit is contained in:
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
import platform
|
||||
|
||||
@@ -5,8 +5,7 @@
|
||||
PrimAITE |VERSION| Configuration
|
||||
********************************
|
||||
|
||||
PrimAITE uses a single configuration file to define everything needed to train and evaluate an RL policy in a custom cybersecurity scenario. This includes the configuration of the network, the scripted or trained agents that interact with the network, as well as settings that define how to perform training in Stable Baselines 3 or Ray RLLib.
|
||||
The entire config is used by the ``PrimaiteSession`` object for users who wish to let PrimAITE handle the agent definition and training. If you wish to define custom agents and control the training loop yourself, you can use the config with the ``PrimaiteGame``, and ``PrimaiteGymEnv`` objects instead. That way, only the network configuration and agent setup parts of the config are used, and the training section is ignored.
|
||||
PrimAITE uses a single configuration file to define everything needed to create the training environment for RL agents, including the network, the scripted agents, and the RL agent's action space, observation space, and reward function.
|
||||
|
||||
Example Configuration Hierarchy
|
||||
###############################
|
||||
@@ -14,8 +13,6 @@ The top level configuration items in a configuration file is as follows
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
training_config:
|
||||
...
|
||||
io_settings:
|
||||
...
|
||||
game:
|
||||
@@ -33,7 +30,6 @@ Configurable items
|
||||
.. toctree::
|
||||
:maxdepth: 1
|
||||
|
||||
configuration/training_config.rst
|
||||
configuration/io_settings.rst
|
||||
configuration/game.rst
|
||||
configuration/agents.rst
|
||||
|
||||
@@ -13,42 +13,12 @@ This section configures how PrimAITE saves data during simulation and training.
|
||||
.. code-block:: yaml
|
||||
|
||||
io_settings:
|
||||
save_final_model: True
|
||||
save_checkpoints: False
|
||||
checkpoint_interval: 10
|
||||
# save_logs: True
|
||||
# save_transactions: False
|
||||
save_agent_actions: True
|
||||
save_step_metadata: False
|
||||
save_pcap_logs: False
|
||||
save_sys_logs: False
|
||||
|
||||
``save_final_model``
|
||||
--------------------
|
||||
|
||||
Optional. Default value is ``True``.
|
||||
|
||||
Only used if training with PrimaiteSession.
|
||||
If ``True``, the policy will be saved after the final training iteration.
|
||||
|
||||
|
||||
``save_checkpoints``
|
||||
--------------------
|
||||
|
||||
Optional. Default value is ``False``.
|
||||
|
||||
Only used if training with PrimaiteSession.
|
||||
If ``True``, the policy will be saved periodically during training.
|
||||
|
||||
|
||||
``checkpoint_interval``
|
||||
-----------------------
|
||||
|
||||
Optional. Default value is ``10``.
|
||||
|
||||
Only used if training with PrimaiteSession and if ``save_checkpoints`` is ``True``.
|
||||
Defines how often to save the policy during training.
|
||||
|
||||
|
||||
``save_logs``
|
||||
-------------
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
``training_config``
|
||||
===================
|
||||
Configuration items relevant to how the Reinforcement Learning agent(s) will be trained.
|
||||
|
||||
``training_config`` hierarchy
|
||||
-----------------------------
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
training_config:
|
||||
rl_framework: SB3 # or RLLIB_single_agent or RLLIB_multi_agent
|
||||
rl_algorithm: PPO # or A2C
|
||||
n_learn_episodes: 5
|
||||
max_steps_per_episode: 200
|
||||
n_eval_episodes: 1
|
||||
deterministic_eval: True
|
||||
seed: 123
|
||||
|
||||
|
||||
``rl_framework``
|
||||
----------------
|
||||
The RL (Reinforcement Learning) Framework to use in the training session
|
||||
|
||||
Options available are:
|
||||
|
||||
- ``SB3`` (Stable Baselines 3)
|
||||
- ``RLLIB_single_agent`` (Single Agent Ray RLLib)
|
||||
- ``RLLIB_multi_agent`` (Multi Agent Ray RLLib)
|
||||
|
||||
``rl_algorithm``
|
||||
----------------
|
||||
The Reinforcement Learning Algorithm to use in the training session
|
||||
|
||||
Options available are:
|
||||
|
||||
- ``PPO`` (Proximal Policy Optimisation)
|
||||
- ``A2C`` (Advantage Actor Critic)
|
||||
|
||||
``n_learn_episodes``
|
||||
--------------------
|
||||
The number of episodes to train the agent(s).
|
||||
This should be an integer value above ``0``
|
||||
|
||||
``max_steps_per_episode``
|
||||
-------------------------
|
||||
The number of steps each episode will last for.
|
||||
This should be an integer value above ``0``.
|
||||
|
||||
|
||||
``n_eval_episodes``
|
||||
-------------------
|
||||
Optional. Default value is ``0``.
|
||||
|
||||
The number of evaluation episodes to run the trained agent for.
|
||||
This should be an integer value above ``0``.
|
||||
|
||||
``deterministic_eval``
|
||||
----------------------
|
||||
Optional. By default this value is ``False``.
|
||||
|
||||
If this is set to ``True``, the agents will act deterministically instead of stochastically.
|
||||
|
||||
|
||||
|
||||
``seed``
|
||||
--------
|
||||
Optional.
|
||||
|
||||
The seed is used (alongside ``deterministic_eval``) to reproduce a previous instance of training and evaluation of an RL agent.
|
||||
The seed should be an integer value.
|
||||
Useful for debugging.
|
||||
@@ -10,15 +10,6 @@ The simulator and game layer communicate using the PrimAITE State API and the Pr
|
||||
|
||||
The game layer is responsible for managing agents and getting them to interface with the simulator correctly. It consists of several components:
|
||||
|
||||
PrimAITE Session
|
||||
================
|
||||
|
||||
.. admonition:: Deprecated
|
||||
:class: deprecated
|
||||
|
||||
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The `session` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
|
||||
|
||||
``PrimaiteSession`` is the main entry point into Primaite and it allows the simultaneous coordination of a simulation and agents that interact with it. ``PrimaiteSession`` keeps track of multiple agents of different types.
|
||||
|
||||
Agents
|
||||
======
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
.. only:: comment
|
||||
|
||||
© Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
|
||||
.. _run a primaite session:
|
||||
|
||||
.. admonition:: Deprecated
|
||||
:class: deprecated
|
||||
|
||||
PrimAITE Session is being deprecated in favour of Jupyter Notebooks. The ``session`` command will be removed in future releases, but example notebooks will be provided to demonstrate the same functionality.
|
||||
|
||||
Run a PrimAITE Session
|
||||
======================
|
||||
|
||||
``PrimaiteSession`` allows the user to train or evaluate an RL agent on the primaite simulation with just a config file,
|
||||
no code required. It manages the lifecycle of a training or evaluation session, including the setup of the environment,
|
||||
policy, simulator, agents, and IO.
|
||||
|
||||
If you want finer control over the RL policy, you can interface with the :py:module::`primaite.session.environment`
|
||||
module directly without running a session.
|
||||
|
||||
|
||||
|
||||
Run
|
||||
---
|
||||
|
||||
A PrimAITE session can be started either with the ``primaite session`` command from the cli
|
||||
(See :func:`primaite.cli.session`), or by calling :func:`primaite.main.run` from a Python terminal or Jupyter Notebook.
|
||||
|
||||
There are two parameters that can be specified:
|
||||
- ``--config``: The path to the config file to use. If not specified, the default config file is used.
|
||||
- ``--agent-load-file``: The path to the pre-trained agent to load. If not specified, a new agent is created.
|
||||
|
||||
Outputs
|
||||
-------
|
||||
|
||||
Running a session creates a session output directory in your user data folder. The filepath looks like this:
|
||||
``~/primaite/{VERSION}/sessions/YYYY-MM-DD/HH-MM-SS/``. This folder contains the simulation sys logs generated by each node,
|
||||
the saved agent checkpoints, and final model. The folder also contains a .json file for each episode step that
|
||||
contains the action, reward, and simulation state. These can be found in
|
||||
``~/primaite/{VERSION}/sessions/YYYY-MM-DD/HH-MM-SS/simulation_output/episode_<n>/step_metadata/step_<n>.json``
|
||||
@@ -114,23 +114,3 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
config: Optional[str] = None,
|
||||
agent_load_file: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.main import run
|
||||
|
||||
if not config:
|
||||
config = data_manipulation_config_path()
|
||||
print(config)
|
||||
run(config_path=config, agent_load_path=agent_load_file)
|
||||
|
||||
@@ -1,17 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: RLLIB_multi_agent
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 2
|
||||
agent_references:
|
||||
- defender_1
|
||||
- defender_2
|
||||
|
||||
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
|
||||
@@ -210,8 +210,8 @@ class PrimaiteGame:
|
||||
"""Create a PrimaiteGame object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent.
|
||||
2. game_config: options for the game itself. Used by PrimaiteGame.
|
||||
1. io_settings: options for logging data during training
|
||||
2. game_config: options for the game itself, such as agents.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
The specification for each of the three major areas is described in a separate documentation page.
|
||||
|
||||
@@ -1,47 +0,0 @@
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
"""The main PrimAITE session runner module."""
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.load import data_manipulation_config_path, load
|
||||
from primaite.session.session import PrimaiteSession
|
||||
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def run(
|
||||
config_path: Optional[Union[str, Path]] = "",
|
||||
agent_load_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
|
||||
:param training_config_path: YAML file containing configurable items defined in
|
||||
`primaite.config.training_config.TrainingConfig`
|
||||
:type training_config_path: Union[path, str]
|
||||
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
|
||||
:type lay_down_config_path: Union[path, str]
|
||||
:param session_path: directory path of the session to load
|
||||
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
|
||||
otherwise False.
|
||||
"""
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--config")
|
||||
|
||||
args = parser.parse_args()
|
||||
if not args.config:
|
||||
args.config = data_manipulation_config_path()
|
||||
|
||||
run(args.config)
|
||||
@@ -48,7 +48,7 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
# make ProxyAgent store the action chosen by the RL policy
|
||||
step = self.game.step_counter
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
|
||||
@@ -95,6 +95,7 @@ class PrimaiteIO:
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict) -> "PrimaiteIO":
|
||||
"""Create an instance of PrimaiteIO based on a configuration dict."""
|
||||
config = config or {}
|
||||
new = cls(settings=cls.Settings(**config))
|
||||
|
||||
return new
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
from primaite.session.policy.rllib import RaySingleAgentPolicy
|
||||
from primaite.session.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Base class and common logic for RL policies."""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
"""Base class for reinforcement learning agents."""
|
||||
|
||||
_registry: Dict[str, Type["PolicyABC"]] = {}
|
||||
"""
|
||||
Registry of policy types, keyed by name.
|
||||
|
||||
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a policy subclass.
|
||||
|
||||
:param name: Identifier used by from_config to create an instance of the policy.
|
||||
:type name: str
|
||||
:raises ValueError: When attempting to create a policy with a duplicate name.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate policy name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: "PrimaiteSession") -> None:
|
||||
"""
|
||||
Initialize a reinforcement learning policy.
|
||||
|
||||
:param session: The session context.
|
||||
:type session: PrimaiteSession
|
||||
:param agents: The agents to train.
|
||||
:type agents: List[RLAgent]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> None:
|
||||
"""Load agent from a file."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the agent."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
|
||||
"""
|
||||
Create an RL policy from a config by calling the relevant subclass's from_config method.
|
||||
|
||||
Subclasses should not call super().from_config(), they should just handle creation form config.
|
||||
"""
|
||||
# Assume that basically the contents of training_config are passed into here.
|
||||
# I should really define a config schema class using pydantic.
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config(config=config, session=session)
|
||||
@@ -1,111 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, TYPE_CHECKING
|
||||
|
||||
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
import ray
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms import ppo
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
|
||||
"""Single agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
super().__init__(session=session)
|
||||
|
||||
self.config = {
|
||||
"env": PrimaiteRayEnv,
|
||||
"env_config": {"game": session.game},
|
||||
"disable_env_checking": True,
|
||||
"num_rollout_workers": 0,
|
||||
}
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
self.config["training_iterations"] = n_episodes * timesteps_per_episode
|
||||
self.config["train_batch_size"] = 128
|
||||
self._algo = ppo.PPO(config=self.config)
|
||||
_LOGGER.info("Starting RLLIB training session")
|
||||
self._algo.train()
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
for ep in range(n_episodes):
|
||||
obs, info = self.session.env.reset()
|
||||
for step in range(self.session.game.options.max_episode_length):
|
||||
action = self._algo.compute_single_action(observation=obs, explore=False)
|
||||
obs, rew, term, trunc, info = self.session.env.step(action)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the policy to a file."""
|
||||
self._algo.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy parameters from a file."""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
|
||||
"""Create a policy from a config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
|
||||
|
||||
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
|
||||
"""Mutli agent RL policy using Ray RLLib."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
|
||||
"""Initialise multi agent policy wrapper."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self.config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.multi_agent(
|
||||
policies={agent.agent_name for agent in session.game.rl_agents},
|
||||
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
|
||||
tune.Tuner(
|
||||
"PPO",
|
||||
run_config=air.RunConfig(
|
||||
stop={"training_iteration": n_episodes * timesteps_per_episode},
|
||||
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
|
||||
),
|
||||
param_space=self.config,
|
||||
).fit()
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load policy parameters from a file."""
|
||||
return NotImplemented
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate trained policy."""
|
||||
return NotImplemented
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save policy parameters to a file."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
|
||||
"""Create policy from config."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
@@ -1,79 +0,0 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.session.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self._agent_class: Type[Union[PPO, A2C]]
|
||||
if algorithm == "PPO":
|
||||
self._agent_class = PPO
|
||||
policy = PPO_MLP
|
||||
elif algorithm == "A2C":
|
||||
self._agent_class = A2C
|
||||
policy = A2C_MLP
|
||||
else:
|
||||
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
|
||||
self._agent = self._agent_class(
|
||||
policy=policy,
|
||||
env=self.session.env,
|
||||
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
if self.session.save_checkpoints:
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
|
||||
save_path=self.session.io_manager.generate_model_save_path("sb3"),
|
||||
name_prefix="sb3_model",
|
||||
)
|
||||
else:
|
||||
checkpoint_callback = None
|
||||
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
_ = evaluate_policy(
|
||||
self._agent,
|
||||
self.session.env,
|
||||
n_eval_episodes=n_episodes,
|
||||
deterministic=deterministic,
|
||||
return_episode_rewards=True,
|
||||
)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""
|
||||
Save the current policy parameters.
|
||||
|
||||
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
|
||||
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
|
||||
Therefore, this method is only used to save the final model.
|
||||
"""
|
||||
self._agent.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load agent from a checkpoint."""
|
||||
self._agent = self._agent_class.load(model_path, env=self.session.env)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
@@ -1,119 +0,0 @@
|
||||
# raise DeprecationWarning("This module is deprecated")
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from primaite.session.io import PrimaiteIO
|
||||
|
||||
# from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.policy.policy import PolicyABC
|
||||
|
||||
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
|
||||
|
||||
def __init__(self, game_cfg: Dict):
|
||||
"""Initialise PrimaiteSession object."""
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
|
||||
"""The environment that the RL algorithm can consume."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.io_manager: Optional["PrimaiteIO"] = None
|
||||
"""IO manager for the session."""
|
||||
|
||||
self.game_cfg: Dict = game_cfg
|
||||
"""Primaite Game object for managing main simulation loop and agents."""
|
||||
|
||||
self.save_checkpoints: bool = False
|
||||
"""Whether to save checkpoints."""
|
||||
|
||||
self.checkpoint_interval: int = 10
|
||||
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training/eval session."""
|
||||
print("Starting Primaite Session")
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary."""
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
|
||||
|
||||
sess = cls(game_cfg=cfg)
|
||||
sess.io_manager = io_manager
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
|
||||
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
if sess.training_options.rl_framework == "RLLIB_single_agent":
|
||||
sess.env = PrimaiteRayEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
|
||||
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
|
||||
elif sess.training_options.rl_framework == "SB3":
|
||||
sess.env = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
return sess
|
||||
@@ -271,7 +271,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
|
||||
# Update the state with information from Layer3Interface
|
||||
state.update(Layer3Interface.describe_state(self))
|
||||
|
||||
state["frequency"] = self.frequency
|
||||
state["frequency"] = self.frequency.value
|
||||
|
||||
return state
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# flake8: noqa
|
||||
raise DeprecationWarning(
|
||||
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
|
||||
)
|
||||
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
|
||||
import csv
|
||||
from logging import Logger
|
||||
|
||||
@@ -1,12 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
se3ed: 333 # Purposeful typo to check that error is raised with bad configuration.
|
||||
n_learn_steps: 2560
|
||||
n_eval_episodes: 5
|
||||
|
||||
|
||||
|
||||
game:
|
||||
ports:
|
||||
- ARP
|
||||
|
||||
@@ -1,16 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 0
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
|
||||
game:
|
||||
ports:
|
||||
- ARP
|
||||
|
||||
@@ -30,18 +30,6 @@
|
||||
# | external_computer |------| switch_3 |------| external_server |
|
||||
# ----------------------- -------------- ---------------------
|
||||
#
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: true
|
||||
|
||||
@@ -1,21 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: RLLIB_multi_agent
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 2
|
||||
n_eval_episodes: 1
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references: #not used :(
|
||||
- defender1
|
||||
- defender2
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 128
|
||||
ports:
|
||||
@@ -31,11 +13,12 @@ game:
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
type: PeriodicAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
|
||||
options:
|
||||
nodes:
|
||||
@@ -901,86 +884,6 @@ agents:
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 9
|
||||
38:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
39:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
40:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
41:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
42:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
43:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
44:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
45:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
46:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
47:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
48:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
49:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
50:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
51:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
52:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
53:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
|
||||
|
||||
options:
|
||||
|
||||
@@ -1,15 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_agent_actions: false
|
||||
save_step_metadata: false
|
||||
|
||||
@@ -1,15 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 1
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
|
||||
@@ -1,15 +1,3 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 10
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: true
|
||||
@@ -568,7 +556,7 @@ agents:
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
flatten_obs: true
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,737 +0,0 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 10
|
||||
n_eval_episodes: 0
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
|
||||
game:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_2_green_user
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: data_manipulation_attacker
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client_1
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_settings:
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
hosts:
|
||||
- hostname: domain_controller
|
||||
- hostname: web_server
|
||||
services:
|
||||
- service_name: WebServer
|
||||
- hostname: database_server
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- hostname: backup_server
|
||||
- hostname: security_suite
|
||||
- hostname: client_1
|
||||
- hostname: client_2
|
||||
num_services: 1
|
||||
num_applications: 0
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 2
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
routers:
|
||||
- hostname: router_1
|
||||
num_ports: 0
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
wildcard_list:
|
||||
- 0.0.0.1
|
||||
port_list:
|
||||
- 80
|
||||
- 5432
|
||||
protocol_list:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
num_rules: 10
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- router_1:eth-1<->switch_1:eth-8
|
||||
- router_1:eth-2<->switch_2:eth-8
|
||||
- switch_1:eth-1<->domain_controller:eth-1
|
||||
- switch_1:eth-2<->web_server:eth-1
|
||||
- switch_1:eth-3<->database_server:eth-1
|
||||
- switch_1:eth-4<->backup_server:eth-1
|
||||
- switch_1:eth-7<->security_suite:eth-1
|
||||
- switch_2:eth-1<->client_1:eth-1
|
||||
- switch_2:eth-2<->client_2:eth-1
|
||||
- switch_2:eth-7<->security_suite:eth-2
|
||||
- type: "NONE"
|
||||
label: ICS
|
||||
options: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_SERVICE_FIX
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: ROUTER_ACL_ADDRULE
|
||||
- type: ROUTER_ACL_REMOVERULE
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
service_id: 0
|
||||
9: # check database.db file
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
file_id: 0
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
file_id: 0
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
file_id: 0
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
file_id: 0
|
||||
13:
|
||||
action: "NODE_SERVICE_FIX"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 0
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 2
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 2
|
||||
19: # shutdown client 1
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 5
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
node_id: 5
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 5
|
||||
22: # "ACL: ADDRULE - Block outgoing traffic from client 1" (not supported in Primaite)
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
23: # "ACL: ADDRULE - Block outgoing traffic from client 2" (not supported in Primaite)
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 2
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 1 # ALL
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
24: # block tcp traffic from client 1 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 3
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
25: # block tcp traffic from client 2 to web app
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 4
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 3 # web server
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
26:
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 5
|
||||
permission: 2
|
||||
source_ip_id: 7 # client 1
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
27:
|
||||
action: "ROUTER_ACL_ADDRULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 6
|
||||
permission: 2
|
||||
source_ip_id: 8 # client 2
|
||||
dest_ip_id: 4 # database
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
source_wildcard_id: 0
|
||||
dest_wildcard_id: 0
|
||||
28:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 0
|
||||
29:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 1
|
||||
30:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 2
|
||||
31:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 3
|
||||
32:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 4
|
||||
33:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 5
|
||||
34:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 6
|
||||
35:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 7
|
||||
36:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 8
|
||||
37:
|
||||
action: "ROUTER_ACL_REMOVERULE"
|
||||
options:
|
||||
target_router_nodename: router_1
|
||||
position: 9
|
||||
38:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
39:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
40:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
41:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
42:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
43:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 0
|
||||
44:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
45:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 0
|
||||
46:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
47:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 0
|
||||
48:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
49:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
50:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
51:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 0
|
||||
52:
|
||||
action: "HOST_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
53:
|
||||
action: "HOST_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 0
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_name: domain_controller
|
||||
- node_name: web_server
|
||||
- node_name: database_server
|
||||
- node_name: backup_server
|
||||
- node_name: security_suite
|
||||
- node_name: client_1
|
||||
- node_name: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
ip_list:
|
||||
- 192.168.1.10
|
||||
- 192.168.1.12
|
||||
- 192.168.1.14
|
||||
- 192.168.1.16
|
||||
- 192.168.1.110
|
||||
- 192.168.10.21
|
||||
- 192.168.10.22
|
||||
- 192.168.10.110
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_hostname: web_server
|
||||
service_name: web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: WebServer
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
|
||||
|
||||
- type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: DatabaseService
|
||||
|
||||
- type: server
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- type: FTPServer
|
||||
|
||||
- type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
network_interfaces:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
port_scan_p_of_success: 0.1
|
||||
data_manipulation_p_of_success: 0.1
|
||||
payload: "DELETE"
|
||||
server_ip: 192.168.1.14
|
||||
services:
|
||||
- type: DNSClient
|
||||
|
||||
- type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- type: WebBrowser
|
||||
services:
|
||||
- type: DNSClient
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 8
|
||||
- endpoint_a_hostname: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: switch_2
|
||||
endpoint_b_port: 8
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: web_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_hostname: database_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_hostname: backup_server
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: client_1
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_hostname: client_2
|
||||
endpoint_b_port: 1
|
||||
- endpoint_a_hostname: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_hostname: security_suite
|
||||
endpoint_b_port: 2
|
||||
@@ -13,7 +13,6 @@ from primaite.game.agent.interface import AbstractAgent
|
||||
from primaite.game.agent.observations.observation_manager import NestedObservation, ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.session import PrimaiteSession
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.container import Network
|
||||
@@ -121,38 +120,6 @@ def file_system() -> FileSystem:
|
||||
return computer.file_system
|
||||
|
||||
|
||||
# PrimAITE v2 stuff
|
||||
class TempPrimaiteSession(PrimaiteSession):
|
||||
"""
|
||||
A temporary PrimaiteSession class.
|
||||
|
||||
Uses context manager for deletion of files upon exit.
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession":
|
||||
"""Create a temporary PrimaiteSession object from a config file."""
|
||||
config_path = Path(config_path)
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
return super().from_config(cfg=config)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type, value, tb):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
|
||||
"""Create a temporary PrimaiteSession object."""
|
||||
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
|
||||
config_path = request.param[0]
|
||||
return TempPrimaiteSession.from_config(config_path=config_path)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def client_server() -> Tuple[Computer, Server]:
|
||||
network = Network()
|
||||
|
||||
@@ -1,33 +1,28 @@
|
||||
import pytest
|
||||
import ray
|
||||
import yaml
|
||||
from ray import air, tune
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
|
||||
from primaite.config.load import data_manipulation_config_path
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.environment import PrimaiteRayMARLEnv
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Slow, reenable later")
|
||||
def test_rllib_multi_agent_compatibility():
|
||||
"""Test that the PrimaiteRayEnv class can be used with a multi agent RLLIB system."""
|
||||
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
with open(MULTI_AGENT_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
game = PrimaiteGame.from_config(cfg)
|
||||
|
||||
ray.shutdown()
|
||||
ray.init()
|
||||
|
||||
env_config = {"game": game}
|
||||
config = (
|
||||
PPOConfig()
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config={"game": game})
|
||||
.environment(env=PrimaiteRayMARLEnv, env_config=cfg)
|
||||
.rollouts(num_rollout_workers=0)
|
||||
.multi_agent(
|
||||
policies={agent.agent_name for agent in game.rl_agents},
|
||||
policies={agent["ref"] for agent in cfg["agents"]},
|
||||
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
|
||||
)
|
||||
.training(train_batch_size=128)
|
||||
|
||||
91
tests/e2e_integration_tests/test_environment.py
Normal file
91
tests/e2e_integration_tests/test_environment.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pydantic
|
||||
import pytest
|
||||
import yaml
|
||||
from gymnasium.core import ObsType
|
||||
from numpy import ndarray
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayMARLEnv
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
class TestPrimaiteEnvironment:
|
||||
def test_creating_env(self):
|
||||
"""Check that environment loads correctly from config and it can be reset."""
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
def env_checks():
|
||||
assert env is not None
|
||||
assert env.game.simulation
|
||||
assert len(env.game.agents) == 3
|
||||
assert len(env.game.rl_agents) == 1
|
||||
|
||||
assert env.game.simulation.network
|
||||
assert len(env.game.simulation.network.nodes) == 12
|
||||
wireless = env.game.simulation.network.get_node_by_hostname("router_2")
|
||||
assert isinstance(wireless, WirelessRouter)
|
||||
printer = env.game.simulation.network.get_node_by_hostname("HP_LaserJet_Pro_4102fdn_printer")
|
||||
assert isinstance(printer, Printer)
|
||||
|
||||
env_checks()
|
||||
env.reset()
|
||||
env_checks()
|
||||
|
||||
def test_step_env(self):
|
||||
"""Make sure you can go all the way through the session without errors."""
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
|
||||
assert (num_actions := len(env.agent.action_manager.action_map)) == 54
|
||||
# run every action and make sure there's no crash
|
||||
for act in range(num_actions):
|
||||
env.step(act)
|
||||
# try running action number outside the action map to check that it fails.
|
||||
with pytest.raises(KeyError):
|
||||
env.step(num_actions)
|
||||
|
||||
obs, rew, trunc, term, info = env.step(0)
|
||||
assert isinstance(obs, ndarray)
|
||||
|
||||
def test_multi_agent_env(self):
|
||||
"""Check that we can run a training session with a multi agent system."""
|
||||
with open(MULTI_AGENT_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteRayMARLEnv(env_config=cfg)
|
||||
|
||||
assert set(env._agent_ids) == {"defender1", "defender2"}
|
||||
|
||||
assert len(env.agents) == 2
|
||||
defender1 = env.agents["defender1"]
|
||||
defender2 = env.agents["defender2"]
|
||||
assert (num_actions_1 := len(defender1.action_manager.action_map)) == 54
|
||||
assert (num_actions_2 := len(defender2.action_manager.action_map)) == 38
|
||||
|
||||
# ensure we can run all valid actions without error
|
||||
for act_1 in range(num_actions_1):
|
||||
env.step({"defender1": act_1, "defender2": 0})
|
||||
for act_2 in range(num_actions_2):
|
||||
env.step({"defender1": 0, "defender2": act_2})
|
||||
|
||||
# ensure we get error when taking an invalid action
|
||||
with pytest.raises(KeyError):
|
||||
env.step({"defender1": num_actions_1, "defender2": 0})
|
||||
with pytest.raises(KeyError):
|
||||
env.step({"defender1": 0, "defender2": num_actions_2})
|
||||
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
"""Make sure we throw an error when the config is bad."""
|
||||
with open(MISCONFIGURED_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
@@ -1,109 +0,0 @@
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from primaite.simulator.network.hardware.nodes.host.server import Printer
|
||||
from primaite.simulator.network.hardware.nodes.network.wireless_router import WirelessRouter
|
||||
from tests import TEST_ASSETS_ROOT
|
||||
from tests.conftest import TempPrimaiteSession
|
||||
|
||||
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
|
||||
MULTI_AGENT_PATH = TEST_ASSETS_ROOT / "configs/multi_agent_session.yaml"
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
class TestPrimaiteSession:
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_creating_session(self, temp_primaite_session):
|
||||
"""Check that creating a session from config works."""
|
||||
with temp_primaite_session as session:
|
||||
if not isinstance(session, TempPrimaiteSession):
|
||||
raise AssertionError
|
||||
|
||||
assert session is not None
|
||||
assert session.env.game.simulation
|
||||
assert len(session.env.game.agents) == 3
|
||||
assert len(session.env.game.rl_agents) == 1
|
||||
|
||||
assert session.policy
|
||||
assert session.env
|
||||
|
||||
assert session.env.game.simulation.network
|
||||
assert len(session.env.game.simulation.network.nodes) == 12
|
||||
wireless = session.env.game.simulation.network.get_node_by_hostname("router_2")
|
||||
assert isinstance(wireless, WirelessRouter)
|
||||
printer = session.env.game.simulation.network.get_node_by_hostname("HP_LaserJet_Pro_4102fdn_printer")
|
||||
assert isinstance(printer, Printer)
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_start_session(self, temp_primaite_session):
|
||||
"""Make sure you can go all the way through the session without errors."""
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
|
||||
session_path = session.io_manager.session_path
|
||||
assert session_path.exists()
|
||||
print(list(session_path.glob("*")))
|
||||
checkpoint_dir = session_path / "checkpoints" / "sb3_final"
|
||||
assert checkpoint_dir.exists()
|
||||
checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
|
||||
checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
|
||||
checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
|
||||
assert checkpoint_1.exists()
|
||||
assert checkpoint_2.exists()
|
||||
assert not checkpoint_3.exists()
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
|
||||
def test_training_only_session(self, temp_primaite_session):
|
||||
"""Check that you can run a training-only session."""
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
|
||||
def test_eval_only_session(self, temp_primaite_session):
|
||||
"""Check that you can load a model and run an eval-only session."""
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
# TODO: include checks that the model was loaded and that the eval-only session ran
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.skip(reason="Slow, reenable later")
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[MULTI_AGENT_PATH]], indirect=True)
|
||||
def test_multi_agent_session(self, temp_primaite_session):
|
||||
"""Check that we can run a training session with a multi agent system."""
|
||||
with temp_primaite_session as session:
|
||||
session.start_session()
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
|
||||
|
||||
@pytest.mark.skip(reason="Session is not being maintained and will be removed in the subsequent beta release.")
|
||||
@pytest.mark.skip(
|
||||
reason="Currently software cannot be dynamically created/destroyed during simulation. Therefore, "
|
||||
"reset doesn't implement software restore."
|
||||
)
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_session_sim_reset(self, temp_primaite_session):
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
|
||||
client_1.software_manager.uninstall("DataManipulationBot")
|
||||
|
||||
assert "DataManipulationBot" not in client_1.software_manager.software
|
||||
|
||||
session.game.reset()
|
||||
client_1 = session.game.simulation.network.get_node_by_hostname("client_1")
|
||||
|
||||
assert "DataManipulationBot" in client_1.software_manager.software
|
||||
Reference in New Issue
Block a user