Merge remote-tracking branch 'origin/dev' into feature/1972-remove-sqlite

This commit is contained in:
Marek Wolan
2023-11-18 03:41:22 +00:00
29 changed files with 3556 additions and 387 deletions

View File

@@ -29,17 +29,6 @@ jobs:
pip install -e .[dev]
displayName: 'Install PrimAITE for docs autosummary'
- script: |
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
python -m pip install $GATE_WHEEL[dev]
displayName: 'Install GATE'
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
- script: |
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
displayName: 'Install GATE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
primaite setup
displayName: 'Perform PrimAITE Setup'

View File

@@ -81,17 +81,6 @@ stages:
displayName: 'Install PrimAITE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
python -m pip install $GATE_WHEEL[dev]
displayName: 'Install GATE'
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
- script: |
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
displayName: 'Install GATE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
primaite setup
displayName: 'Perform PrimAITE Setup'

View File

@@ -46,7 +46,6 @@ python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
.\.venv\Scripts\activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install GATE/arcd_gate-0.1.0-py3-none-any.whl
primaite setup
```
@@ -75,7 +74,6 @@ cd ~/primaite
python3 -m venv .venv
source .venv/bin/activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install arcd_gate-0.1.0-py3-none-any.whl
primaite setup
```
@@ -120,7 +118,6 @@ source venv/bin/activate
```bash
python3 -m pip install -e .[dev]
pip install arcd_gate-0.1.0-py3-none-any.whl
```
#### 6. Perform the PrimAITE setup:

View File

@@ -38,7 +38,7 @@ dependencies = [
"stable-baselines3[extra]==2.1.0",
"tensorflow==2.12.0",
"typer[all]==0.9.0",
"pydantic==2.1.1"
"pydantic==2.1.1",
]
[tool.setuptools.dynamic]

View File

@@ -29,6 +29,15 @@ class _PrimaitePaths:
def __init__(self) -> None:
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
self.user_home_path = self.generate_user_home_path()
self.user_sessions_path = self.generate_user_sessions_path()
self.user_config_path = self.generate_user_config_path()
self.user_notebooks_path = self.generate_user_notebooks_path()
self.app_home_path = self.generate_app_home_path()
self.app_config_dir_path = self.generate_app_config_dir_path()
self.app_config_file_path = self.generate_app_config_file_path()
self.app_log_dir_path = self.generate_app_log_dir_path()
self.app_log_file_path = self.generate_app_log_file_path()
def _get_dirs_properties(self) -> List[str]:
class_items = self.__class__.__dict__.items()
@@ -43,55 +52,47 @@ class _PrimaitePaths:
for p in self._get_dirs_properties():
getattr(self, p)
@property
def user_home_path(self) -> Path:
def generate_user_home_path(self) -> Path:
"""The PrimAITE user home path."""
path = Path.home() / "primaite" / __version__
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_sessions_path(self) -> Path:
def generate_user_sessions_path(self) -> Path:
"""The PrimAITE user sessions path."""
path = self.user_home_path / "sessions"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_config_path(self) -> Path:
def generate_user_config_path(self) -> Path:
"""The PrimAITE user config path."""
path = self.user_home_path / "config"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_notebooks_path(self) -> Path:
def generate_user_notebooks_path(self) -> Path:
"""The PrimAITE user notebooks path."""
path = self.user_home_path / "notebooks"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_home_path(self) -> Path:
def generate_app_home_path(self) -> Path:
"""The PrimAITE app home path."""
path = self._dirs.user_data_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_dir_path(self) -> Path:
def generate_app_config_dir_path(self) -> Path:
"""The PrimAITE app config directory path."""
path = self._dirs.user_config_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_file_path(self) -> Path:
def generate_app_config_file_path(self) -> Path:
"""The PrimAITE app config file path."""
return self.app_config_dir_path / "primaite_config.yaml"
@property
def app_log_dir_path(self) -> Path:
def generate_app_log_dir_path(self) -> Path:
"""The PrimAITE app log directory path."""
if sys.platform == "win32":
path = self.app_home_path / "logs"
@@ -100,8 +101,7 @@ class _PrimaitePaths:
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_log_file_path(self) -> Path:
def generate_app_log_file_path(self) -> Path:
"""The PrimAITE app log file path."""
return self.app_log_dir_path / "primaite.log"
@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARN": logging.WARN,
"WARNING": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}

View File

@@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None:
WARNING: All user-data will be lost.
"""
from arcd_gate.cli import setup as gate_setup
from primaite import getLogger
from primaite.setup import reset_demo_notebooks, reset_example_configs
@@ -115,15 +113,13 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Rebuilding the example notebooks...")
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("Setting up ARCD GATE...")
gate_setup()
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
config: Optional[str] = None,
agent_load_file: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
@@ -131,16 +127,10 @@ def session(
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from threading import Thread
from primaite.config.load import example_config_path
from primaite.main import run
from primaite.utils.start_gate_server import start_gate_server
server_thread = Thread(target=start_gate_server)
server_thread.start()
if not config:
config = example_config_path()
print(config)
run(config_path=config)
run(config_path=config, agent_load_path=agent_load_file)

View File

@@ -2,10 +2,17 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 20
n_learn_steps: 128
n_eval_episodes: 20
n_eval_steps: 128
n_learn_episodes: 25
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game_config:
@@ -108,7 +115,7 @@ game_config:
- ref: defender
team: BLUE
type: GATERLAgent
type: ProxyAgent
observation_space:
type: UC2BlueObservation

View File

@@ -1,31 +0,0 @@
# flake8: noqa
from typing import Dict, Optional, Tuple
from gymnasium.core import ActType, ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractGATEAgent, ObsType
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
class GATERLAgent(AbstractGATEAgent):
...
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
# because when we are supporting MARL, the actions form multiple agents will have to be batched
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
# with the actions for those agents.
def __init__(
self,
agent_name: str | None,
action_space: ActionManager | None,
observation_space: ObservationSpace | None,
reward_function: RewardFunction | None,
) -> None:
super().__init__(agent_name, action_space, observation_space, reward_function)
self.most_recent_action: ActType
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
return self.most_recent_action

View File

@@ -1,15 +1,13 @@
"""Interface for agents."""
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
from typing import Dict, List, Optional, Tuple
import numpy as np
from gymnasium.core import ActType, ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
ObsType: TypeAlias = Union[Dict, np.ndarray]
class AbstractAgent(ABC):
"""Base class for scripted and RL agents."""
@@ -18,7 +16,7 @@ class AbstractAgent(ABC):
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
) -> None:
"""
@@ -34,24 +32,24 @@ class AbstractAgent(ABC):
:type reward_function: Optional[RewardFunction]
"""
self.agent_name: str = agent_name or "unnamed_agent"
self.action_space: Optional[ActionManager] = action_space
self.observation_space: Optional[ObservationSpace] = observation_space
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
# by for example specifying target ip addresses, or converting a node ID into a uuid
self.execution_definition = None
def convert_state_to_obs(self, state: Dict) -> ObsType:
def update_observation(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_space.observe(state)
return self.observation_manager.update(state)
def calculate_reward_from_state(self, state: Dict) -> float:
def update_reward(self, state: Dict) -> float:
"""
Use the reward function to calculate a reward from the state.
@@ -60,10 +58,10 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.calculate(state)
return self.reward_function.update(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
@@ -76,7 +74,7 @@ class AbstractAgent(ABC):
:return: Action to be taken in the environment.
:rtype: Tuple[str, Dict]
"""
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {})
@@ -84,7 +82,7 @@ class AbstractAgent(ABC):
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.action_space.form_request(action_identifier=action, action_options=options)
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
@@ -97,7 +95,7 @@ class AbstractScriptedAgent(AbstractAgent):
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
@@ -107,10 +105,44 @@ class RandomAgent(AbstractScriptedAgent):
:return: _description_
:rtype: Tuple[str, Dict]
"""
return self.action_space.get_action(self.action_space.space.sample())
return self.action_manager.get_action(self.action_manager.space.sample())
class AbstractGATEAgent(AbstractAgent):
"""Base class for actors controlled via external messages, such as RL policies."""
class ProxyAgent(AbstractAgent):
"""Agent that sends observations to an RL model and receives actions from that model."""
...
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
) -> None:
super().__init__(
agent_name=agent_name,
action_space=action_space,
observation_space=observation_space,
reward_function=reward_function,
)
self.most_recent_action: ActType
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""
Return the agent's most recent action, formatted in CAOS format.
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
:type obs: ObsType
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
:type reward: float, optional
:return: Action to be taken in CAOS format.
:rtype: Tuple[str, Dict]
"""
return self.action_manager.get_action(self.most_recent_action)
def store_action(self, action: ActType):
"""
Store the most recent action taken by the agent.
The environment is responsible for calling this method when it receives an action from the agent policy.
"""
self.most_recent_action = action

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation):
pass
class ObservationSpace:
class ObservationManager:
"""
Manage the observations of an Agent.
@@ -947,15 +948,17 @@ class ObservationSpace:
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.current_observation: ObsType
def observe(self, state: Dict) -> Dict:
def update(self, state: Dict) -> Dict:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
"""
return self.obs.observe(state)
self.current_observation = self.obs.observe(state)
return self.current_observation
@property
def space(self) -> None:
@@ -963,7 +966,7 @@ class ObservationSpace:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.

View File

@@ -26,7 +26,7 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward):
class RewardFunction:
"""Manages the reward function for the agent."""
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
__rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
@@ -238,6 +238,7 @@ class RewardFunction:
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.
@@ -249,7 +250,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def calculate(self, state: Dict) -> float:
def update(self, state: Dict) -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
@@ -260,7 +261,8 @@ class RewardFunction:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
return total
self.current_reward = total
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":

63
src/primaite/game/io.py Normal file
View File

@@ -0,0 +1,63 @@
from datetime import datetime
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, ConfigDict
from primaite import PRIMAITE_PATHS
from primaite.simulator import SIM_OUTPUT
class SessionIOSettings(BaseModel):
"""Schema for session IO settings."""
model_config = ConfigDict(extra="forbid")
save_final_model: bool = True
"""Whether to save the final model right at the end of training."""
save_checkpoints: bool = False
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
checkpoint_interval: int = 10
"""How often to save a checkpoint model (if save_checkpoints is True)."""
save_logs: bool = True
"""Whether to save logs"""
save_transactions: bool = True
"""Whether to save transactions, If true, the session path will have a transactions folder."""
save_tensorboard_logs: bool = False
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
class SessionIO:
"""
Class for managing session IO.
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
"""
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
self.settings: SessionIOSettings = settings
self.session_path: Path = self.generate_session_path()
# set global SIM_OUTPUT path
SIM_OUTPUT.path = self.session_path / "simulation_output"
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
# possible refactor needed
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
"""Create a folder for the session and return the path to it."""
if timestamp is None:
timestamp = datetime.now()
date_str = timestamp.strftime("%Y-%m-%d")
time_str = timestamp.strftime("%H-%M-%S")
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def generate_model_save_path(self, agent_name: str) -> Path:
"""Return the path where the final model will be saved (excluding filename extension)."""
return self.session_path / "checkpoints" / f"{agent_name}_final"
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
"""Return the path where the checkpoint model will be saved (excluding filename extension)."""
return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"

View File

@@ -0,0 +1,3 @@
from primaite.game.policy.sb3 import SB3Policy
__all__ = ["SB3Policy"]

View File

@@ -0,0 +1,84 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
"""Base class for reinforcement learning agents."""
_registry: Dict[str, Type["PolicyABC"]] = {}
"""
Registry of policy types, keyed by name.
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register a policy subclass.
:param name: Identifier used by from_config to create an instance of the policy.
:type name: str
:raises ValueError: When attempting to create a policy with a duplicate name.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate policy name {identifier}")
cls._registry[identifier] = cls
return
@abstractmethod
def __init__(self, session: "PrimaiteSession") -> None:
"""
Initialize a reinforcement learning policy.
:param session: The session context.
:type session: PrimaiteSession
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteSession" = session
"""Reference to the session."""
@abstractmethod
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
pass
@abstractmethod
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
"""Evaluate the agent."""
pass
@abstractmethod
def save(self, save_path: Path) -> None:
"""Save the agent."""
pass
@abstractmethod
def load(self) -> None:
"""Load agent from a file."""
pass
def close(self) -> None:
"""Close the agent."""
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
"""
Create an RL policy from a config by calling the relevant subclass's from_config method.
Subclasses should not call super().from_config(), they should just handle creation form config.
"""
# Assume that basically the contents of training_config are passed into here.
# I should really define a config schema class using pydantic.
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass

View File

@@ -0,0 +1,80 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
self._agent_class: Type[Union[PPO, A2C]]
if algorithm == "PPO":
self._agent_class = PPO
policy = PPO_MLP
elif algorithm == "A2C":
self._agent_class = A2C
policy = A2C_MLP
else:
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
self._agent = self._agent_class(
policy=policy,
env=self.session.env,
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
seed=seed,
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
if self.session.io_manager.settings.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)
else:
checkpoint_callback = None
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
reward_data = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
print(reward_data)
def save(self, save_path: Path) -> None:
"""
Save the current policy parameters.
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
Therefore, this method is only used to save the final model.
"""
self._agent.save(save_path)
def load(self, model_path: Path) -> None:
"""Load agent from a checkpoint."""
self._agent = self._agent_class.load(model_path, env=self.session.env)
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
"""Create an agent from config file."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,18 +1,20 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
from arcd_gate.client.gate_client import ActType, GATEClient
from gymnasium import spaces
import gymnasium
from gymnasium.core import ActType, ObsType
from gymnasium.spaces.utils import flatten, flatten_space
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, RandomAgent
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -34,109 +36,62 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteGATEClient(GATEClient):
"""Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
"""
Create a new GATE client for PrimAITE.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
:param parent_session: The parent session object.
:type parent_session: PrimaiteSession
:param service_port: The port on which the GATE service is running.
:type service_port: int, optional
"""
super().__init__(service_port=service_port)
self.parent_session: "PrimaiteSession" = parent_session
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
"""Initialise the environment."""
super().__init__()
self.session: "PrimaiteSession" = session
self.agent: ProxyAgent = agents[0]
@property
def rl_framework(self) -> str:
"""The reinforcement learning framework to use."""
return self.parent_session.training_options.rl_framework
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.session.apply_agent_actions()
self.session.advance_timestep()
state = self.session.get_sim_state()
self.session.update_agents(state)
@property
def rl_algorithm(self) -> str:
"""The reinforcement learning algorithm to use."""
return self.parent_session.training_options.rl_algorithm
@property
def seed(self) -> int | None:
"""The seed to use for the environment's random number generator."""
return self.parent_session.training_options.seed
@property
def n_learn_episodes(self) -> int:
"""The number of episodes in each learning run."""
return self.parent_session.training_options.n_learn_episodes
@property
def n_learn_steps(self) -> int:
"""The number of steps in each learning episode."""
return self.parent_session.training_options.n_learn_steps
@property
def n_eval_episodes(self) -> int:
"""The number of episodes in each evaluation run."""
return self.parent_session.training_options.n_eval_episodes
@property
def n_eval_steps(self) -> int:
"""The number of steps in each evaluation episode."""
return self.parent_session.training_options.n_eval_steps
@property
def action_space(self) -> spaces.Space:
"""The gym action space of the agent."""
return self.parent_session.rl_agent.action_space.space
@property
def observation_space(self) -> spaces.Space:
"""The gymnasium observation space of the agent."""
return flatten_space(self.parent_session.rl_agent.observation_space.space)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
"""Take a step in the environment.
This method is called by GATE to advance the simulation by one timestep.
:param action: The agent's action.
:type action: ActType
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
:rtype: Tuple[ObsType, float, bool, bool, Dict]
"""
self.parent_session.rl_agent.most_recent_action = action
self.parent_session.step()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
rew = self.parent_session.rl_agent.reward_function.calculate(state)
term = False
trunc = False
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.session.calculate_truncated()
info = {}
return obs, rew, term, trunc, info
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
"""Reset the environment.
return next_obs, reward, terminated, truncated, info
This method is called when the environment is initialized and at the end of each episode.
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.session.reset()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
:param seed: The seed to use for the environment's random number generator.
:type seed: int, optional
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
compatibility with GATE.
:type options: dict[str, Any], optional
:return: The initial observation and an empty info dictionary.
:rtype: Tuple[ObsType, Dict]
"""
self.parent_session.reset()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
return obs, {}
@property
def action_space(self) -> gymnasium.Space:
"""Return the action space of the environment."""
return self.agent.action_manager.space
def close(self):
"""Close the session, this will stop the gate client and close the simulation."""
self.parent_session.close()
@property
def observation_space(self) -> gymnasium.Space:
"""Return the observation space of the environment."""
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
"""Return the current observation."""
unflat_space = self.agent.observation_manager.space
unflat_obs = self.agent.observation_manager.current_observation
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
@@ -146,6 +101,8 @@ class PrimaiteSessionOptions(BaseModel):
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
"""
model_config = ConfigDict(extra="forbid")
ports: List[str]
protocols: List[str]
@@ -153,48 +110,102 @@ class PrimaiteSessionOptions(BaseModel):
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
rl_framework: str
rl_algorithm: str
seed: Optional[int]
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_learn_steps: int
n_eval_episodes: int
n_eval_steps: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
def __init__(self):
"""Initialise a PrimaiteSession object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agent: AbstractAgent
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
self.gate_client.start()
"""Commence the training session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
def step(self):
"""
@@ -208,57 +219,76 @@ class PrimaiteSession:
4. Each agent chooses an action based on the observation.
5. Each agent converts the action to a request.
6. The simulation applies the requests.
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
interacts with should implement a step method that calls methods used by this method. For example, if using a
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# currently designed with assumption that all agents act once per step in order
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
def get_sim_state(self) -> Dict:
"""Get the current state of the simulation."""
return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
# 3. primaite session asks simulation to provide initial state
# 4. primate session gives state to all agents
# 5. primaite session asks agents to produce an action based on most recent state
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
sim_state = self.simulation.describe_state()
agent.update_observation(state)
agent.update_reward(state)
# 6. each agent takes most recent state and converts it to CAOS observation
agent_obs = agent.convert_state_to_obs(sim_state)
def apply_agent_actions(self) -> None:
"""Apply all actions to simulation as requests."""
for agent in self.agents:
obs = agent.observation_manager.current_observation
rew = agent.reward_function.current_reward
action_choice, options = agent.get_action(obs, rew)
request = agent.format_request(action_choice, options)
self.simulation.apply_request(request)
# 7. meanwhile each agent also takes state and calculates reward
agent_reward = agent.calculate_reward_from_state(sim_state)
# 8. each agent takes observation and applies decision rule to observation to create CAOS
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
# code should live inside of the GATE agent subclass)
# gets action in CAOS format
_LOGGER.debug("Getting agent action")
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
agent_request = agent.format_request(agent_action, action_options)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
self.simulation.apply_request(agent_request)
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
self.simulation.apply_timestep(self.step_counter)
def advance_timestep(self) -> None:
"""Advance timestep."""
self.step_counter += 1
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
self.simulation.apply_timestep(self.step_counter)
def calculate_truncated(self) -> bool:
"""Calculate whether the episode is truncated."""
current_step = self.step_counter
max_steps = self.training_options.max_steps_per_episode
if current_step >= max_steps:
return True
return False
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
return NotImplemented
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation.reset_component_for_episode(self.episode_counter)
def close(self) -> None:
"""Close the session, this will stop the gate client and close the simulation."""
"""Close the session, this will stop the env and close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent. Used by GATE.
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
3. simulation: defines the network topology and the initial state of the simulation.
@@ -276,6 +306,11 @@ class PrimaiteSession:
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
sim = sess.simulation
net = sim.network
@@ -412,7 +447,7 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
@@ -448,15 +483,15 @@ class PrimaiteSession:
reward_function=rew_function,
)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
new_agent = RandomAgent(
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
)
sess.agents.append(new_agent)
sess.rl_agent = new_agent
sess.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = RandomAgent(
agent_name=agent_cfg["ref"],
@@ -468,4 +503,12 @@ class PrimaiteSession:
else:
print("agent type not found")
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess

View File

@@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__)
def run(
config_path: Optional[Union[str, Path]] = "",
agent_load_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
@@ -31,7 +32,7 @@ def run(
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()

View File

@@ -1,14 +1,26 @@
"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory."""
from datetime import datetime
from pathlib import Path
from primaite import _PRIMAITE_ROOT
SIM_OUTPUT = None
"A path at the repo root dir to use temporarily for sim output testing while in dev."
# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path
__all__ = ["SIM_OUTPUT"]
if not SIM_OUTPUT:
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path
SIM_OUTPUT.mkdir(exist_ok=True, parents=True)
class __SimOutput:
def __init__(self):
self._path: Path = (
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
)
@property
def path(self) -> Path:
return self._path
@path.setter
def path(self, new_path: Path) -> None:
self._path = new_path
self._path.mkdir(exist_ok=True, parents=True)
SIM_OUTPUT = __SimOutput()

View File

@@ -957,7 +957,7 @@ class Node(SimComponent):
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):

View File

@@ -75,7 +75,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = SIM_OUTPUT / self.hostname
root = SIM_OUTPUT.path / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -41,7 +41,6 @@ class SysLog:
JSON-like messages.
"""
log_path = self._get_log_path()
file_handler = logging.FileHandler(filename=log_path)
file_handler.setLevel(logging.DEBUG)
@@ -81,7 +80,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = SIM_OUTPUT / self.hostname
root = SIM_OUTPUT.path / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -1,12 +0,0 @@
"""Utility script to start the gate server for running PrimAITE in attached mode."""
from arcd_gate.server.gate_service import GATEService
def start_gate_server():
"""Start the gate server."""
service = GATEService()
service.start()
if __name__ == "__main__":
start_gate_server()

View File

@@ -0,0 +1,725 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
se3ed: 333 # Purposeful typo to check that error is raised with bad configuration.
n_learn_steps: 2560
n_eval_episodes: 5
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -0,0 +1,729 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 0
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -0,0 +1,733 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 10
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -0,0 +1,729 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 10
n_eval_episodes: 0
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -5,12 +5,13 @@ import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
from unittest.mock import patch
import nodeenv
import pytest
import yaml
from primaite import getLogger
from primaite.game.session import PrimaiteSession
# from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession
@@ -20,13 +21,15 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
from primaite import PRIMAITE_PATHS
# PrimAITE v3 stuff
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.base import Node
@@ -71,102 +74,32 @@ def file_system() -> FileSystem:
# PrimAITE v2 stuff
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
class TempPrimaiteSession: # PrimaiteSession):
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.
Uses context manager for deletion of files upon exit.
"""
# def __init__(
# self,
# training_config_path: Union[str, Path],
# lay_down_config_path: Union[str, Path],
# ):
# super().__init__(training_config_path, lay_down_config_path)
# self.setup()
@classmethod
def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession":
"""Create a temporary PrimaiteSession object from a config file."""
config_path = Path(config_path)
with open(config_path, "r") as f:
config = yaml.safe_load(f)
# @property
# def env(self) -> Primaite:
# """Direct access to the env for ease of testing."""
# return self._agent_session._env # noqa
return super().from_config(cfg=config)
# def __enter__(self):
# return self
def __enter__(self):
return self
# def __exit__(self, type, value, tb):
# shutil.rmtree(self.session_path)
# _LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
def __exit__(self, type, value, tb):
pass
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
@pytest.fixture
def temp_primaite_session(request):
"""
Provides a temporary PrimaiteSession instance.
It's temporary as it uses a temporary directory as the session path.
To use this fixture you need to:
- parametrize your test function with:
- "temp_primaite_session"
- [[path to training config, path to lay down config]]
- Include the temp_primaite_session fixture as a param in your test
function.
- use the temp_primaite_session as a context manager assigning is the
name 'session'.
.. code:: python
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
@pytest.mark.parametrize(
"temp_primaite_session",
[
[main_training_config_path(), dos_very_basic_config_path()]
],
indirect=True
)
def test_primaite_session(temp_primaite_session):
with temp_primaite_session as session:
# Learning outputs are saved in session.learning_path
session.learn()
# Evaluation outputs are saved in session.evaluation_path
session.evaluate()
# To ensure that all files are written, you must call .close()
session.close()
# If you need to inspect any session outputs, it must be done
# inside the context manager
# Now that we've exited the context manager, the
# session.session_path directory and its contents are deleted
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
@pytest.fixture
def temp_session_path() -> Path:
"""
Get a temp directory session path the test session will output to.
:return: The session directory path.
"""
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "_primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
"""Create a temporary PrimaiteSession object."""
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
config_path = request.param[0]
return TempPrimaiteSession.from_config(config_path=config_path)

View File

@@ -0,0 +1,68 @@
import pydantic
import pytest
from tests.conftest import TempPrimaiteSession
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
class TestPrimaiteSession:
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_creating_session(self, temp_primaite_session):
"""Check that creating a session from config works."""
with temp_primaite_session as session:
if not isinstance(session, TempPrimaiteSession):
raise AssertionError
assert session is not None
assert session.simulation
assert len(session.agents) == 3
assert len(session.rl_agents) == 1
assert session.policy
assert session.env
assert session.simulation.network
assert len(session.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):
"""Make sure you can go all the way through the session without errors."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
session_path = session.io_manager.session_path
assert session_path.exists()
print(list(session_path.glob("*")))
checkpoint_dir = session_path / "checkpoints" / "sb3_final"
assert checkpoint_dir.exists()
checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
assert checkpoint_1.exists()
assert checkpoint_2.exists()
assert not checkpoint_3.exists()
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
def test_training_only_session(self, temp_primaite_session):
"""Check that you can run a training-only session."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
def test_eval_only_session(self, temp_primaite_session):
"""Check that you can load a model and run an eval-only session."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
# TODO: include checks that the model was loaded and that the eval-only session ran
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)

View File

@@ -9,7 +9,7 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
def get_temp_session_path(session_timestamp: datetime) -> Path:
def temp_user_sessions_path() -> Path:
"""
Get a temp directory session path the test session will output to.