Merge remote-tracking branch 'origin/dev' into feature/1972-remove-sqlite
This commit is contained in:
@@ -29,17 +29,6 @@ jobs:
|
||||
pip install -e .[dev]
|
||||
displayName: 'Install PrimAITE for docs autosummary'
|
||||
|
||||
- script: |
|
||||
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
|
||||
python -m pip install $GATE_WHEEL[dev]
|
||||
displayName: 'Install GATE'
|
||||
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
|
||||
|
||||
- script: |
|
||||
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
|
||||
displayName: 'Install GATE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
- script: |
|
||||
primaite setup
|
||||
displayName: 'Perform PrimAITE Setup'
|
||||
|
||||
@@ -81,17 +81,6 @@ stages:
|
||||
displayName: 'Install PrimAITE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
- script: |
|
||||
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
|
||||
python -m pip install $GATE_WHEEL[dev]
|
||||
displayName: 'Install GATE'
|
||||
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
|
||||
|
||||
- script: |
|
||||
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
|
||||
displayName: 'Install GATE'
|
||||
condition: eq( variables['Agent.OS'], 'Windows_NT' )
|
||||
|
||||
- script: |
|
||||
primaite setup
|
||||
displayName: 'Perform PrimAITE Setup'
|
||||
|
||||
@@ -46,7 +46,6 @@ python3 -m venv .venv
|
||||
attrib +h .venv /s /d # Hides the .venv directory
|
||||
.\.venv\Scripts\activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install GATE/arcd_gate-0.1.0-py3-none-any.whl
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -75,7 +74,6 @@ cd ~/primaite
|
||||
python3 -m venv .venv
|
||||
source .venv/bin/activate
|
||||
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
|
||||
pip install arcd_gate-0.1.0-py3-none-any.whl
|
||||
primaite setup
|
||||
```
|
||||
|
||||
@@ -120,7 +118,6 @@ source venv/bin/activate
|
||||
|
||||
```bash
|
||||
python3 -m pip install -e .[dev]
|
||||
pip install arcd_gate-0.1.0-py3-none-any.whl
|
||||
```
|
||||
|
||||
#### 6. Perform the PrimAITE setup:
|
||||
|
||||
@@ -38,7 +38,7 @@ dependencies = [
|
||||
"stable-baselines3[extra]==2.1.0",
|
||||
"tensorflow==2.12.0",
|
||||
"typer[all]==0.9.0",
|
||||
"pydantic==2.1.1"
|
||||
"pydantic==2.1.1",
|
||||
]
|
||||
|
||||
[tool.setuptools.dynamic]
|
||||
|
||||
@@ -29,6 +29,15 @@ class _PrimaitePaths:
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
|
||||
self.user_home_path = self.generate_user_home_path()
|
||||
self.user_sessions_path = self.generate_user_sessions_path()
|
||||
self.user_config_path = self.generate_user_config_path()
|
||||
self.user_notebooks_path = self.generate_user_notebooks_path()
|
||||
self.app_home_path = self.generate_app_home_path()
|
||||
self.app_config_dir_path = self.generate_app_config_dir_path()
|
||||
self.app_config_file_path = self.generate_app_config_file_path()
|
||||
self.app_log_dir_path = self.generate_app_log_dir_path()
|
||||
self.app_log_file_path = self.generate_app_log_file_path()
|
||||
|
||||
def _get_dirs_properties(self) -> List[str]:
|
||||
class_items = self.__class__.__dict__.items()
|
||||
@@ -43,55 +52,47 @@ class _PrimaitePaths:
|
||||
for p in self._get_dirs_properties():
|
||||
getattr(self, p)
|
||||
|
||||
@property
|
||||
def user_home_path(self) -> Path:
|
||||
def generate_user_home_path(self) -> Path:
|
||||
"""The PrimAITE user home path."""
|
||||
path = Path.home() / "primaite" / __version__
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_sessions_path(self) -> Path:
|
||||
def generate_user_sessions_path(self) -> Path:
|
||||
"""The PrimAITE user sessions path."""
|
||||
path = self.user_home_path / "sessions"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_config_path(self) -> Path:
|
||||
def generate_user_config_path(self) -> Path:
|
||||
"""The PrimAITE user config path."""
|
||||
path = self.user_home_path / "config"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def user_notebooks_path(self) -> Path:
|
||||
def generate_user_notebooks_path(self) -> Path:
|
||||
"""The PrimAITE user notebooks path."""
|
||||
path = self.user_home_path / "notebooks"
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_home_path(self) -> Path:
|
||||
def generate_app_home_path(self) -> Path:
|
||||
"""The PrimAITE app home path."""
|
||||
path = self._dirs.user_data_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_dir_path(self) -> Path:
|
||||
def generate_app_config_dir_path(self) -> Path:
|
||||
"""The PrimAITE app config directory path."""
|
||||
path = self._dirs.user_config_path
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_config_file_path(self) -> Path:
|
||||
def generate_app_config_file_path(self) -> Path:
|
||||
"""The PrimAITE app config file path."""
|
||||
return self.app_config_dir_path / "primaite_config.yaml"
|
||||
|
||||
@property
|
||||
def app_log_dir_path(self) -> Path:
|
||||
def generate_app_log_dir_path(self) -> Path:
|
||||
"""The PrimAITE app log directory path."""
|
||||
if sys.platform == "win32":
|
||||
path = self.app_home_path / "logs"
|
||||
@@ -100,8 +101,7 @@ class _PrimaitePaths:
|
||||
path.mkdir(exist_ok=True, parents=True)
|
||||
return path
|
||||
|
||||
@property
|
||||
def app_log_file_path(self) -> Path:
|
||||
def generate_app_log_file_path(self) -> Path:
|
||||
"""The PrimAITE app log file path."""
|
||||
return self.app_log_dir_path / "primaite.log"
|
||||
|
||||
@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
|
||||
"DEBUG": logging.DEBUG,
|
||||
"INFO": logging.INFO,
|
||||
"WARN": logging.WARN,
|
||||
"WARNING": logging.WARN,
|
||||
"ERROR": logging.ERROR,
|
||||
"CRITICAL": logging.CRITICAL,
|
||||
}
|
||||
|
||||
@@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
|
||||
WARNING: All user-data will be lost.
|
||||
"""
|
||||
from arcd_gate.cli import setup as gate_setup
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.setup import reset_demo_notebooks, reset_example_configs
|
||||
|
||||
@@ -115,15 +113,13 @@ def setup(overwrite_existing: bool = True) -> None:
|
||||
_LOGGER.info("Rebuilding the example notebooks...")
|
||||
reset_example_configs.run(overwrite_existing=True)
|
||||
|
||||
_LOGGER.info("Setting up ARCD GATE...")
|
||||
gate_setup()
|
||||
|
||||
_LOGGER.info("PrimAITE setup complete!")
|
||||
|
||||
|
||||
@app.command()
|
||||
def session(
|
||||
config: Optional[str] = None,
|
||||
agent_load_file: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run a PrimAITE session.
|
||||
@@ -131,16 +127,10 @@ def session(
|
||||
:param config: The path to the config file. Optional, if None, the example config will be used.
|
||||
:type config: Optional[str]
|
||||
"""
|
||||
from threading import Thread
|
||||
|
||||
from primaite.config.load import example_config_path
|
||||
from primaite.main import run
|
||||
from primaite.utils.start_gate_server import start_gate_server
|
||||
|
||||
server_thread = Thread(target=start_gate_server)
|
||||
server_thread.start()
|
||||
|
||||
if not config:
|
||||
config = example_config_path()
|
||||
print(config)
|
||||
run(config_path=config)
|
||||
run(config_path=config, agent_load_path=agent_load_file)
|
||||
|
||||
@@ -2,10 +2,17 @@ training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 20
|
||||
n_learn_steps: 128
|
||||
n_eval_episodes: 20
|
||||
n_eval_steps: 128
|
||||
n_learn_episodes: 25
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
|
||||
|
||||
game_config:
|
||||
@@ -108,7 +115,7 @@ game_config:
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: GATERLAgent
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
# flake8: noqa
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractGATEAgent, ObsType
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
|
||||
class GATERLAgent(AbstractGATEAgent):
|
||||
...
|
||||
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
|
||||
# because when we are supporting MARL, the actions form multiple agents will have to be batched
|
||||
|
||||
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
|
||||
# with the actions for those agents.
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: str | None,
|
||||
action_space: ActionManager | None,
|
||||
observation_space: ObservationSpace | None,
|
||||
reward_function: RewardFunction | None,
|
||||
) -> None:
|
||||
super().__init__(agent_name, action_space, observation_space, reward_function)
|
||||
self.most_recent_action: ActType
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
return self.most_recent_action
|
||||
@@ -1,15 +1,13 @@
|
||||
"""Interface for agents."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, TypeAlias, Union
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from gymnasium.core import ActType, ObsType
|
||||
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
|
||||
ObsType: TypeAlias = Union[Dict, np.ndarray]
|
||||
|
||||
|
||||
class AbstractAgent(ABC):
|
||||
"""Base class for scripted and RL agents."""
|
||||
@@ -18,7 +16,7 @@ class AbstractAgent(ABC):
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationSpace],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
"""
|
||||
@@ -34,24 +32,24 @@ class AbstractAgent(ABC):
|
||||
:type reward_function: Optional[RewardFunction]
|
||||
"""
|
||||
self.agent_name: str = agent_name or "unnamed_agent"
|
||||
self.action_space: Optional[ActionManager] = action_space
|
||||
self.observation_space: Optional[ObservationSpace] = observation_space
|
||||
self.action_manager: Optional[ActionManager] = action_space
|
||||
self.observation_manager: Optional[ObservationManager] = observation_space
|
||||
self.reward_function: Optional[RewardFunction] = reward_function
|
||||
|
||||
# exection definiton converts CAOS action to Primaite simulator request, sometimes having to enrich the info
|
||||
# by for example specifying target ip addresses, or converting a node ID into a uuid
|
||||
self.execution_definition = None
|
||||
|
||||
def convert_state_to_obs(self, state: Dict) -> ObsType:
|
||||
def update_observation(self, state: Dict) -> ObsType:
|
||||
"""
|
||||
Convert a state from the simulator into an observation for the agent using the observation space.
|
||||
|
||||
state : dict state directly from simulation.describe_state
|
||||
output : dict state according to CAOS.
|
||||
"""
|
||||
return self.observation_space.observe(state)
|
||||
return self.observation_manager.update(state)
|
||||
|
||||
def calculate_reward_from_state(self, state: Dict) -> float:
|
||||
def update_reward(self, state: Dict) -> float:
|
||||
"""
|
||||
Use the reward function to calculate a reward from the state.
|
||||
|
||||
@@ -60,10 +58,10 @@ class AbstractAgent(ABC):
|
||||
:return: Reward from the state.
|
||||
:rtype: float
|
||||
"""
|
||||
return self.reward_function.calculate(state)
|
||||
return self.reward_function.update(state)
|
||||
|
||||
@abstractmethod
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return an action to be taken in the environment.
|
||||
|
||||
@@ -76,7 +74,7 @@ class AbstractAgent(ABC):
|
||||
:return: Action to be taken in the environment.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
|
||||
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
|
||||
# then use a bespoke conversion to take 1-40 int back into CAOS action
|
||||
return ("DO_NOTHING", {})
|
||||
|
||||
@@ -84,7 +82,7 @@ class AbstractAgent(ABC):
|
||||
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
|
||||
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
|
||||
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
|
||||
request = self.action_space.form_request(action_identifier=action, action_options=options)
|
||||
request = self.action_manager.form_request(action_identifier=action, action_options=options)
|
||||
return request
|
||||
|
||||
|
||||
@@ -97,7 +95,7 @@ class AbstractScriptedAgent(AbstractAgent):
|
||||
class RandomAgent(AbstractScriptedAgent):
|
||||
"""Agent that ignores its observation and acts completely at random."""
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""Randomly sample an action from the action space.
|
||||
|
||||
:param obs: _description_
|
||||
@@ -107,10 +105,44 @@ class RandomAgent(AbstractScriptedAgent):
|
||||
:return: _description_
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_space.get_action(self.action_space.space.sample())
|
||||
return self.action_manager.get_action(self.action_manager.space.sample())
|
||||
|
||||
|
||||
class AbstractGATEAgent(AbstractAgent):
|
||||
"""Base class for actors controlled via external messages, such as RL policies."""
|
||||
class ProxyAgent(AbstractAgent):
|
||||
"""Agent that sends observations to an RL model and receives actions from that model."""
|
||||
|
||||
...
|
||||
def __init__(
|
||||
self,
|
||||
agent_name: Optional[str],
|
||||
action_space: Optional[ActionManager],
|
||||
observation_space: Optional[ObservationManager],
|
||||
reward_function: Optional[RewardFunction],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
agent_name=agent_name,
|
||||
action_space=action_space,
|
||||
observation_space=observation_space,
|
||||
reward_function=reward_function,
|
||||
)
|
||||
self.most_recent_action: ActType
|
||||
|
||||
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
|
||||
"""
|
||||
Return the agent's most recent action, formatted in CAOS format.
|
||||
|
||||
:param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
|
||||
:type obs: ObsType
|
||||
:param reward: Reward value for the agent. Not used by ProxyAgents, defaults to None.
|
||||
:type reward: float, optional
|
||||
:return: Action to be taken in CAOS format.
|
||||
:rtype: Tuple[str, Dict]
|
||||
"""
|
||||
return self.action_manager.get_action(self.most_recent_action)
|
||||
|
||||
def store_action(self, action: ActType):
|
||||
"""
|
||||
Store the most recent action taken by the agent.
|
||||
|
||||
The environment is responsible for calling this method when it receives an action from the agent policy.
|
||||
"""
|
||||
self.most_recent_action = action
|
||||
|
||||
@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from gymnasium import spaces
|
||||
from gymnasium.core import ObsType
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation):
|
||||
pass
|
||||
|
||||
|
||||
class ObservationSpace:
|
||||
class ObservationManager:
|
||||
"""
|
||||
Manage the observations of an Agent.
|
||||
|
||||
@@ -947,15 +948,17 @@ class ObservationSpace:
|
||||
:type observation: AbstractObservation
|
||||
"""
|
||||
self.obs: AbstractObservation = observation
|
||||
self.current_observation: ObsType
|
||||
|
||||
def observe(self, state: Dict) -> Dict:
|
||||
def update(self, state: Dict) -> Dict:
|
||||
"""
|
||||
Generate observation based on the current state of the simulation.
|
||||
|
||||
:param state: Simulation state dictionary
|
||||
:type state: Dict
|
||||
"""
|
||||
return self.obs.observe(state)
|
||||
self.current_observation = self.obs.observe(state)
|
||||
return self.current_observation
|
||||
|
||||
@property
|
||||
def space(self) -> None:
|
||||
@@ -963,7 +966,7 @@ class ObservationSpace:
|
||||
return self.obs.space
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
|
||||
"""Create observation space from a config.
|
||||
|
||||
:param config: Dictionary containing the configuration for this observation space.
|
||||
|
||||
@@ -26,7 +26,7 @@ the structure:
|
||||
```
|
||||
"""
|
||||
from abc import abstractmethod
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
|
||||
@@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward):
|
||||
class RewardFunction:
|
||||
"""Manages the reward function for the agent."""
|
||||
|
||||
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
|
||||
__rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
|
||||
"DUMMY": DummyReward,
|
||||
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
|
||||
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
|
||||
@@ -238,6 +238,7 @@ class RewardFunction:
|
||||
"""Initialise the reward function object."""
|
||||
self.reward_components: List[Tuple[AbstractReward, float]] = []
|
||||
"attribute reward_components keeps track of reward components and the weights assigned to each."
|
||||
self.current_reward: float
|
||||
|
||||
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
|
||||
"""Add a reward component to the reward function.
|
||||
@@ -249,7 +250,7 @@ class RewardFunction:
|
||||
"""
|
||||
self.reward_components.append((component, weight))
|
||||
|
||||
def calculate(self, state: Dict) -> float:
|
||||
def update(self, state: Dict) -> float:
|
||||
"""Calculate the overall reward for the current state.
|
||||
|
||||
:param state: The current state of the simulation.
|
||||
@@ -260,7 +261,8 @@ class RewardFunction:
|
||||
comp = comp_and_weight[0]
|
||||
weight = comp_and_weight[1]
|
||||
total += weight * comp.calculate(state=state)
|
||||
return total
|
||||
self.current_reward = total
|
||||
return self.current_reward
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":
|
||||
|
||||
63
src/primaite/game/io.py
Normal file
63
src/primaite/game/io.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
|
||||
class SessionIOSettings(BaseModel):
|
||||
"""Schema for session IO settings."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
save_final_model: bool = True
|
||||
"""Whether to save the final model right at the end of training."""
|
||||
save_checkpoints: bool = False
|
||||
"""Whether to save a checkpoint model every `checkpoint_interval` episodes"""
|
||||
checkpoint_interval: int = 10
|
||||
"""How often to save a checkpoint model (if save_checkpoints is True)."""
|
||||
save_logs: bool = True
|
||||
"""Whether to save logs"""
|
||||
save_transactions: bool = True
|
||||
"""Whether to save transactions, If true, the session path will have a transactions folder."""
|
||||
save_tensorboard_logs: bool = False
|
||||
"""Whether to save tensorboard logs. If true, the session path will have a tenorboard_logs folder."""
|
||||
|
||||
|
||||
class SessionIO:
|
||||
"""
|
||||
Class for managing session IO.
|
||||
|
||||
Currently it's handling path generation, but could expand to handle loading, transaction, tensorboard, and so on.
|
||||
"""
|
||||
|
||||
def __init__(self, settings: SessionIOSettings = SessionIOSettings()) -> None:
|
||||
self.settings: SessionIOSettings = settings
|
||||
self.session_path: Path = self.generate_session_path()
|
||||
|
||||
# set global SIM_OUTPUT path
|
||||
SIM_OUTPUT.path = self.session_path / "simulation_output"
|
||||
|
||||
# warning TODO: must be careful not to re-initialise sessionIO because it will create a new path each time it's
|
||||
# possible refactor needed
|
||||
|
||||
def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
|
||||
"""Create a folder for the session and return the path to it."""
|
||||
if timestamp is None:
|
||||
timestamp = datetime.now()
|
||||
date_str = timestamp.strftime("%Y-%m-%d")
|
||||
time_str = timestamp.strftime("%H-%M-%S")
|
||||
session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
return session_path
|
||||
|
||||
def generate_model_save_path(self, agent_name: str) -> Path:
|
||||
"""Return the path where the final model will be saved (excluding filename extension)."""
|
||||
return self.session_path / "checkpoints" / f"{agent_name}_final"
|
||||
|
||||
def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
|
||||
"""Return the path where the checkpoint model will be saved (excluding filename extension)."""
|
||||
return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"
|
||||
3
src/primaite/game/policy/__init__.py
Normal file
3
src/primaite/game/policy/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from primaite.game.policy.sb3 import SB3Policy
|
||||
|
||||
__all__ = ["SB3Policy"]
|
||||
84
src/primaite/game/policy/policy.py
Normal file
84
src/primaite/game/policy/policy.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Base class and common logic for RL policies."""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Type, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class PolicyABC(ABC):
|
||||
"""Base class for reinforcement learning agents."""
|
||||
|
||||
_registry: Dict[str, Type["PolicyABC"]] = {}
|
||||
"""
|
||||
Registry of policy types, keyed by name.
|
||||
|
||||
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
|
||||
"""
|
||||
|
||||
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
|
||||
"""
|
||||
Register a policy subclass.
|
||||
|
||||
:param name: Identifier used by from_config to create an instance of the policy.
|
||||
:type name: str
|
||||
:raises ValueError: When attempting to create a policy with a duplicate name.
|
||||
"""
|
||||
super().__init_subclass__(**kwargs)
|
||||
if identifier in cls._registry:
|
||||
raise ValueError(f"Duplicate policy name {identifier}")
|
||||
cls._registry[identifier] = cls
|
||||
return
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, session: "PrimaiteSession") -> None:
|
||||
"""
|
||||
Initialize a reinforcement learning policy.
|
||||
|
||||
:param session: The session context.
|
||||
:type session: PrimaiteSession
|
||||
:param agents: The agents to train.
|
||||
:type agents: List[RLAgent]
|
||||
"""
|
||||
self.session: "PrimaiteSession" = session
|
||||
"""Reference to the session."""
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""Save the agent."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self) -> None:
|
||||
"""Load agent from a file."""
|
||||
pass
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the agent."""
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
|
||||
"""
|
||||
Create an RL policy from a config by calling the relevant subclass's from_config method.
|
||||
|
||||
Subclasses should not call super().from_config(), they should just handle creation form config.
|
||||
"""
|
||||
# Assume that basically the contents of training_config are passed into here.
|
||||
# I should really define a config schema class using pydantic.
|
||||
|
||||
PolicyType = cls._registry[config.rl_framework]
|
||||
return PolicyType.from_config(config=config, session=session)
|
||||
|
||||
# saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass
|
||||
80
src/primaite/game/policy/sb3.py
Normal file
80
src/primaite/game/policy/sb3.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Stable baselines 3 policy."""
|
||||
from pathlib import Path
|
||||
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
|
||||
|
||||
from stable_baselines3 import A2C, PPO
|
||||
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
|
||||
from stable_baselines3.common.callbacks import CheckpointCallback
|
||||
from stable_baselines3.common.evaluation import evaluate_policy
|
||||
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
|
||||
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from primaite.game.session import PrimaiteSession, TrainingOptions
|
||||
|
||||
|
||||
class SB3Policy(PolicyABC, identifier="SB3"):
|
||||
"""Single agent RL policy using stable baselines 3."""
|
||||
|
||||
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
|
||||
"""Initialize a stable baselines 3 policy."""
|
||||
super().__init__(session=session)
|
||||
|
||||
self._agent_class: Type[Union[PPO, A2C]]
|
||||
if algorithm == "PPO":
|
||||
self._agent_class = PPO
|
||||
policy = PPO_MLP
|
||||
elif algorithm == "A2C":
|
||||
self._agent_class = A2C
|
||||
policy = A2C_MLP
|
||||
else:
|
||||
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
|
||||
self._agent = self._agent_class(
|
||||
policy=policy,
|
||||
env=self.session.env,
|
||||
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
|
||||
"""Train the agent."""
|
||||
if self.session.io_manager.settings.save_checkpoints:
|
||||
checkpoint_callback = CheckpointCallback(
|
||||
save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
|
||||
save_path=self.session.io_manager.generate_model_save_path("sb3"),
|
||||
name_prefix="sb3_model",
|
||||
)
|
||||
else:
|
||||
checkpoint_callback = None
|
||||
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
|
||||
|
||||
def eval(self, n_episodes: int, deterministic: bool) -> None:
|
||||
"""Evaluate the agent."""
|
||||
reward_data = evaluate_policy(
|
||||
self._agent,
|
||||
self.session.env,
|
||||
n_eval_episodes=n_episodes,
|
||||
deterministic=deterministic,
|
||||
return_episode_rewards=True,
|
||||
)
|
||||
print(reward_data)
|
||||
|
||||
def save(self, save_path: Path) -> None:
|
||||
"""
|
||||
Save the current policy parameters.
|
||||
|
||||
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
|
||||
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
|
||||
Therefore, this method is only used to save the final model.
|
||||
"""
|
||||
self._agent.save(save_path)
|
||||
|
||||
def load(self, model_path: Path) -> None:
|
||||
"""Load agent from a checkpoint."""
|
||||
self._agent = self._agent_class.load(model_path, env=self.session.env)
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
|
||||
"""Create an agent from config file."""
|
||||
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
|
||||
@@ -1,18 +1,20 @@
|
||||
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
|
||||
from enum import Enum
|
||||
from ipaddress import IPv4Address
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
|
||||
|
||||
from arcd_gate.client.gate_client import ActType, GATEClient
|
||||
from gymnasium import spaces
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
from gymnasium.spaces.utils import flatten, flatten_space
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.actions import ActionManager
|
||||
from primaite.game.agent.interface import AbstractAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationSpace
|
||||
from primaite.game.agent.interface import AbstractAgent, ProxyAgent, RandomAgent
|
||||
from primaite.game.agent.observations import ObservationManager
|
||||
from primaite.game.agent.rewards import RewardFunction
|
||||
from primaite.game.io import SessionIO, SessionIOSettings
|
||||
from primaite.game.policy.policy import PolicyABC
|
||||
from primaite.simulator.network.hardware.base import Link, NIC, Node
|
||||
from primaite.simulator.network.hardware.nodes.computer import Computer
|
||||
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
|
||||
@@ -34,109 +36,62 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
class PrimaiteGATEClient(GATEClient):
|
||||
"""Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""
|
||||
class PrimaiteGymEnv(gymnasium.Env):
|
||||
"""
|
||||
Thin wrapper env to provide agents with a gymnasium API.
|
||||
|
||||
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
|
||||
"""
|
||||
Create a new GATE client for PrimAITE.
|
||||
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
:param parent_session: The parent session object.
|
||||
:type parent_session: PrimaiteSession
|
||||
:param service_port: The port on which the GATE service is running.
|
||||
:type service_port: int, optional
|
||||
"""
|
||||
super().__init__(service_port=service_port)
|
||||
self.parent_session: "PrimaiteSession" = parent_session
|
||||
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.session: "PrimaiteSession" = session
|
||||
self.agent: ProxyAgent = agents[0]
|
||||
|
||||
@property
|
||||
def rl_framework(self) -> str:
|
||||
"""The reinforcement learning framework to use."""
|
||||
return self.parent_session.training_options.rl_framework
|
||||
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
|
||||
"""Perform a step in the environment."""
|
||||
# make ProxyAgent store the action chosen my the RL policy
|
||||
self.agent.store_action(action)
|
||||
# apply_agent_actions accesses the action we just stored
|
||||
self.session.apply_agent_actions()
|
||||
self.session.advance_timestep()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
|
||||
@property
|
||||
def rl_algorithm(self) -> str:
|
||||
"""The reinforcement learning algorithm to use."""
|
||||
return self.parent_session.training_options.rl_algorithm
|
||||
|
||||
@property
|
||||
def seed(self) -> int | None:
|
||||
"""The seed to use for the environment's random number generator."""
|
||||
return self.parent_session.training_options.seed
|
||||
|
||||
@property
|
||||
def n_learn_episodes(self) -> int:
|
||||
"""The number of episodes in each learning run."""
|
||||
return self.parent_session.training_options.n_learn_episodes
|
||||
|
||||
@property
|
||||
def n_learn_steps(self) -> int:
|
||||
"""The number of steps in each learning episode."""
|
||||
return self.parent_session.training_options.n_learn_steps
|
||||
|
||||
@property
|
||||
def n_eval_episodes(self) -> int:
|
||||
"""The number of episodes in each evaluation run."""
|
||||
return self.parent_session.training_options.n_eval_episodes
|
||||
|
||||
@property
|
||||
def n_eval_steps(self) -> int:
|
||||
"""The number of steps in each evaluation episode."""
|
||||
return self.parent_session.training_options.n_eval_steps
|
||||
|
||||
@property
|
||||
def action_space(self) -> spaces.Space:
|
||||
"""The gym action space of the agent."""
|
||||
return self.parent_session.rl_agent.action_space.space
|
||||
|
||||
@property
|
||||
def observation_space(self) -> spaces.Space:
|
||||
"""The gymnasium observation space of the agent."""
|
||||
return flatten_space(self.parent_session.rl_agent.observation_space.space)
|
||||
|
||||
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
|
||||
"""Take a step in the environment.
|
||||
|
||||
This method is called by GATE to advance the simulation by one timestep.
|
||||
|
||||
:param action: The agent's action.
|
||||
:type action: ActType
|
||||
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
|
||||
:rtype: Tuple[ObsType, float, bool, bool, Dict]
|
||||
"""
|
||||
self.parent_session.rl_agent.most_recent_action = action
|
||||
self.parent_session.step()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
rew = self.parent_session.rl_agent.reward_function.calculate(state)
|
||||
term = False
|
||||
trunc = False
|
||||
next_obs = self._get_obs()
|
||||
reward = self.agent.reward_function.current_reward
|
||||
terminated = False
|
||||
truncated = self.session.calculate_truncated()
|
||||
info = {}
|
||||
return obs, rew, term, trunc, info
|
||||
|
||||
def reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> Tuple[ObsType, Dict]:
|
||||
"""Reset the environment.
|
||||
return next_obs, reward, terminated, truncated, info
|
||||
|
||||
This method is called when the environment is initialized and at the end of each episode.
|
||||
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
|
||||
"""Reset the environment."""
|
||||
self.session.reset()
|
||||
state = self.session.get_sim_state()
|
||||
self.session.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
info = {}
|
||||
return next_obs, info
|
||||
|
||||
:param seed: The seed to use for the environment's random number generator.
|
||||
:type seed: int, optional
|
||||
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
|
||||
compatibility with GATE.
|
||||
:type options: dict[str, Any], optional
|
||||
:return: The initial observation and an empty info dictionary.
|
||||
:rtype: Tuple[ObsType, Dict]
|
||||
"""
|
||||
self.parent_session.reset()
|
||||
state = self.parent_session.simulation.describe_state()
|
||||
obs = self.parent_session.rl_agent.observation_space.observe(state)
|
||||
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
|
||||
return obs, {}
|
||||
@property
|
||||
def action_space(self) -> gymnasium.Space:
|
||||
"""Return the action space of the environment."""
|
||||
return self.agent.action_manager.space
|
||||
|
||||
def close(self):
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
self.parent_session.close()
|
||||
@property
|
||||
def observation_space(self) -> gymnasium.Space:
|
||||
"""Return the observation space of the environment."""
|
||||
return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
|
||||
|
||||
def _get_obs(self) -> ObsType:
|
||||
"""Return the current observation."""
|
||||
unflat_space = self.agent.observation_manager.space
|
||||
unflat_obs = self.agent.observation_manager.current_observation
|
||||
return gymnasium.spaces.flatten(unflat_space, unflat_obs)
|
||||
|
||||
|
||||
class PrimaiteSessionOptions(BaseModel):
|
||||
@@ -146,6 +101,8 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
ports: List[str]
|
||||
protocols: List[str]
|
||||
|
||||
@@ -153,48 +110,102 @@ class PrimaiteSessionOptions(BaseModel):
|
||||
class TrainingOptions(BaseModel):
|
||||
"""Options for training the RL agent."""
|
||||
|
||||
rl_framework: str
|
||||
rl_algorithm: str
|
||||
seed: Optional[int]
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
rl_framework: Literal["SB3", "RLLIB"]
|
||||
rl_algorithm: Literal["PPO", "A2C"]
|
||||
n_learn_episodes: int
|
||||
n_learn_steps: int
|
||||
n_eval_episodes: int
|
||||
n_eval_steps: int
|
||||
n_eval_episodes: Optional[int] = None
|
||||
max_steps_per_episode: int
|
||||
# checkpoint_freq: Optional[int] = None
|
||||
deterministic_eval: bool
|
||||
seed: Optional[int]
|
||||
n_agents: int
|
||||
agent_references: List[str]
|
||||
|
||||
|
||||
class SessionMode(Enum):
|
||||
"""Helper to keep track of the current session mode."""
|
||||
|
||||
TRAIN = "train"
|
||||
EVAL = "eval"
|
||||
MANUAL = "manual"
|
||||
|
||||
|
||||
class PrimaiteSession:
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
|
||||
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialise a PrimaiteSession object."""
|
||||
self.simulation: Simulation = Simulation()
|
||||
"""Simulation object with which the agents will interact."""
|
||||
|
||||
self.agents: List[AbstractAgent] = []
|
||||
"""List of agents."""
|
||||
self.rl_agent: AbstractAgent
|
||||
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
|
||||
|
||||
self.rl_agents: List[ProxyAgent] = []
|
||||
"""Subset of agent list including only the reinforcement learning agents."""
|
||||
|
||||
self.step_counter: int = 0
|
||||
"""Current timestep within the episode."""
|
||||
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.options: PrimaiteSessionOptions
|
||||
"""Special options that apply for the entire game."""
|
||||
|
||||
self.training_options: TrainingOptions
|
||||
"""Options specific to agent training."""
|
||||
|
||||
self.policy: PolicyABC
|
||||
"""The reinforcement learning policy."""
|
||||
|
||||
self.ref_map_nodes: Dict[str, Node] = {}
|
||||
"""Mapping from unique node reference name to node object. Used when parsing config files."""
|
||||
|
||||
self.ref_map_services: Dict[str, Service] = {}
|
||||
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
|
||||
|
||||
self.ref_map_applications: Dict[str, Application] = {}
|
||||
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
|
||||
|
||||
self.ref_map_links: Dict[str, Link] = {}
|
||||
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
|
||||
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
|
||||
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
|
||||
|
||||
self.env: PrimaiteGymEnv
|
||||
"""The environment that the agent can consume. Could be PrimaiteEnv."""
|
||||
|
||||
self.mode: SessionMode = SessionMode.MANUAL
|
||||
"""Current session mode."""
|
||||
|
||||
self.io_manager = SessionIO()
|
||||
"""IO manager for the session."""
|
||||
|
||||
def start_session(self) -> None:
|
||||
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
|
||||
self.gate_client.start()
|
||||
"""Commence the training session."""
|
||||
self.mode = SessionMode.TRAIN
|
||||
n_learn_episodes = self.training_options.n_learn_episodes
|
||||
n_eval_episodes = self.training_options.n_eval_episodes
|
||||
max_steps_per_episode = self.training_options.max_steps_per_episode
|
||||
|
||||
deterministic_eval = self.training_options.deterministic_eval
|
||||
self.policy.learn(
|
||||
n_episodes=n_learn_episodes,
|
||||
timesteps_per_episode=max_steps_per_episode,
|
||||
)
|
||||
self.save_models()
|
||||
|
||||
self.mode = SessionMode.EVAL
|
||||
if n_eval_episodes > 0:
|
||||
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
|
||||
|
||||
self.mode = SessionMode.MANUAL
|
||||
|
||||
def save_models(self) -> None:
|
||||
"""Save the RL models."""
|
||||
save_path = self.io_manager.generate_model_save_path("temp_model_name")
|
||||
self.policy.save(save_path)
|
||||
|
||||
def step(self):
|
||||
"""
|
||||
@@ -208,57 +219,76 @@ class PrimaiteSession:
|
||||
4. Each agent chooses an action based on the observation.
|
||||
5. Each agent converts the action to a request.
|
||||
6. The simulation applies the requests.
|
||||
|
||||
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
|
||||
interacts with should implement a step method that calls methods used by this method. For example, if using a
|
||||
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
|
||||
``self.apply_agent_actions()``.
|
||||
"""
|
||||
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
|
||||
# currently designed with assumption that all agents act once per step in order
|
||||
|
||||
# Get the current state of the simulation
|
||||
sim_state = self.get_sim_state()
|
||||
|
||||
# Update agents' observations and rewards based on the current state
|
||||
self.update_agents(sim_state)
|
||||
|
||||
# Apply all actions to simulation as requests
|
||||
self.apply_agent_actions()
|
||||
|
||||
# Advance timestep
|
||||
self.advance_timestep()
|
||||
|
||||
def get_sim_state(self) -> Dict:
|
||||
"""Get the current state of the simulation."""
|
||||
return self.simulation.describe_state()
|
||||
|
||||
def update_agents(self, state: Dict) -> None:
|
||||
"""Update agents' observations and rewards based on the current state."""
|
||||
for agent in self.agents:
|
||||
# 3. primaite session asks simulation to provide initial state
|
||||
# 4. primate session gives state to all agents
|
||||
# 5. primaite session asks agents to produce an action based on most recent state
|
||||
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
|
||||
sim_state = self.simulation.describe_state()
|
||||
agent.update_observation(state)
|
||||
agent.update_reward(state)
|
||||
|
||||
# 6. each agent takes most recent state and converts it to CAOS observation
|
||||
agent_obs = agent.convert_state_to_obs(sim_state)
|
||||
def apply_agent_actions(self) -> None:
|
||||
"""Apply all actions to simulation as requests."""
|
||||
for agent in self.agents:
|
||||
obs = agent.observation_manager.current_observation
|
||||
rew = agent.reward_function.current_reward
|
||||
action_choice, options = agent.get_action(obs, rew)
|
||||
request = agent.format_request(action_choice, options)
|
||||
self.simulation.apply_request(request)
|
||||
|
||||
# 7. meanwhile each agent also takes state and calculates reward
|
||||
agent_reward = agent.calculate_reward_from_state(sim_state)
|
||||
|
||||
# 8. each agent takes observation and applies decision rule to observation to create CAOS
|
||||
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
|
||||
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
|
||||
# code should live inside of the GATE agent subclass)
|
||||
# gets action in CAOS format
|
||||
_LOGGER.debug("Getting agent action")
|
||||
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
|
||||
# 9. CAOS action is converted into request (extra information might be needed to enrich
|
||||
# the request, this is what the execution definition is there for)
|
||||
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
|
||||
agent_request = agent.format_request(agent_action, action_options)
|
||||
|
||||
# 10. primaite session receives the action from the agents and asks the simulation to apply each
|
||||
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
|
||||
self.simulation.apply_request(agent_request)
|
||||
|
||||
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
def advance_timestep(self) -> None:
|
||||
"""Advance timestep."""
|
||||
self.step_counter += 1
|
||||
_LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
|
||||
self.simulation.apply_timestep(self.step_counter)
|
||||
|
||||
def calculate_truncated(self) -> bool:
|
||||
"""Calculate whether the episode is truncated."""
|
||||
current_step = self.step_counter
|
||||
max_steps = self.training_options.max_steps_per_episode
|
||||
if current_step >= max_steps:
|
||||
return True
|
||||
return False
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the session, this will reset the simulation."""
|
||||
return NotImplemented
|
||||
self.episode_counter += 1
|
||||
self.step_counter = 0
|
||||
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
|
||||
self.simulation.reset_component_for_episode(self.episode_counter)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the session, this will stop the gate client and close the simulation."""
|
||||
"""Close the session, this will stop the env and close the simulation."""
|
||||
return NotImplemented
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: dict) -> "PrimaiteSession":
|
||||
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
|
||||
"""Create a PrimaiteSession object from a config dictionary.
|
||||
|
||||
The config dictionary should have the following top-level keys:
|
||||
1. training_config: options for training the RL agent. Used by GATE.
|
||||
1. training_config: options for training the RL agent.
|
||||
2. game_config: options for the game itself. Used by PrimaiteSession.
|
||||
3. simulation: defines the network topology and the initial state of the simulation.
|
||||
|
||||
@@ -276,6 +306,11 @@ class PrimaiteSession:
|
||||
protocols=cfg["game_config"]["protocols"],
|
||||
)
|
||||
sess.training_options = TrainingOptions(**cfg["training_config"])
|
||||
|
||||
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
|
||||
io_settings = cfg.get("io_settings", {})
|
||||
sess.io_manager.settings = SessionIOSettings(**io_settings)
|
||||
|
||||
sim = sess.simulation
|
||||
net = sim.network
|
||||
|
||||
@@ -412,7 +447,7 @@ class PrimaiteSession:
|
||||
reward_function_cfg = agent_cfg["reward_function"]
|
||||
|
||||
# CREATE OBSERVATION SPACE
|
||||
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
|
||||
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
|
||||
|
||||
# CREATE ACTION SPACE
|
||||
action_space_cfg["options"]["node_uuids"] = []
|
||||
@@ -448,15 +483,15 @@ class PrimaiteSession:
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
elif agent_type == "GATERLAgent":
|
||||
new_agent = RandomAgent(
|
||||
elif agent_type == "ProxyAgent":
|
||||
new_agent = ProxyAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
action_space=action_space,
|
||||
observation_space=obs_space,
|
||||
reward_function=rew_function,
|
||||
)
|
||||
sess.agents.append(new_agent)
|
||||
sess.rl_agent = new_agent
|
||||
sess.rl_agents.append(new_agent)
|
||||
elif agent_type == "RedDatabaseCorruptingAgent":
|
||||
new_agent = RandomAgent(
|
||||
agent_name=agent_cfg["ref"],
|
||||
@@ -468,4 +503,12 @@ class PrimaiteSession:
|
||||
else:
|
||||
print("agent type not found")
|
||||
|
||||
# CREATE ENVIRONMENT
|
||||
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
|
||||
|
||||
# CREATE POLICY
|
||||
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
|
||||
if agent_load_path:
|
||||
sess.policy.load(Path(agent_load_path))
|
||||
|
||||
return sess
|
||||
|
||||
@@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__)
|
||||
|
||||
def run(
|
||||
config_path: Optional[Union[str, Path]] = "",
|
||||
agent_load_path: Optional[Union[str, Path]] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Run the PrimAITE Session.
|
||||
@@ -31,7 +32,7 @@ def run(
|
||||
otherwise False.
|
||||
"""
|
||||
cfg = load(config_path)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg)
|
||||
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
|
||||
sess.start_session()
|
||||
|
||||
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory."""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from primaite import _PRIMAITE_ROOT
|
||||
|
||||
SIM_OUTPUT = None
|
||||
"A path at the repo root dir to use temporarily for sim output testing while in dev."
|
||||
# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path
|
||||
__all__ = ["SIM_OUTPUT"]
|
||||
|
||||
if not SIM_OUTPUT:
|
||||
session_timestamp = datetime.now()
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path
|
||||
SIM_OUTPUT.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
class __SimOutput:
|
||||
def __init__(self):
|
||||
self._path: Path = (
|
||||
_PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
||||
)
|
||||
|
||||
@property
|
||||
def path(self) -> Path:
|
||||
return self._path
|
||||
|
||||
@path.setter
|
||||
def path(self, new_path: Path) -> None:
|
||||
self._path = new_path
|
||||
self._path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
|
||||
SIM_OUTPUT = __SimOutput()
|
||||
|
||||
@@ -957,7 +957,7 @@ class Node(SimComponent):
|
||||
if not kwargs.get("session_manager"):
|
||||
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
|
||||
if not kwargs.get("root"):
|
||||
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
|
||||
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
|
||||
if not kwargs.get("file_system"):
|
||||
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
|
||||
if not kwargs.get("software_manager"):
|
||||
|
||||
@@ -75,7 +75,7 @@ class PacketCapture:
|
||||
|
||||
def _get_log_path(self) -> Path:
|
||||
"""Get the path for the log file."""
|
||||
root = SIM_OUTPUT / self.hostname
|
||||
root = SIM_OUTPUT.path / self.hostname
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self._logger_name}.log"
|
||||
|
||||
|
||||
@@ -41,7 +41,6 @@ class SysLog:
|
||||
JSON-like messages.
|
||||
"""
|
||||
log_path = self._get_log_path()
|
||||
|
||||
file_handler = logging.FileHandler(filename=log_path)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
|
||||
@@ -81,7 +80,7 @@ class SysLog:
|
||||
|
||||
:return: Path object representing the location of the log file.
|
||||
"""
|
||||
root = SIM_OUTPUT / self.hostname
|
||||
root = SIM_OUTPUT.path / self.hostname
|
||||
root.mkdir(exist_ok=True, parents=True)
|
||||
return root / f"{self.hostname}_sys.log"
|
||||
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
"""Utility script to start the gate server for running PrimAITE in attached mode."""
|
||||
from arcd_gate.server.gate_service import GATEService
|
||||
|
||||
|
||||
def start_gate_server():
|
||||
"""Start the gate server."""
|
||||
service = GATEService()
|
||||
service.start()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
start_gate_server()
|
||||
725
tests/assets/configs/bad_primaite_session.yaml
Normal file
725
tests/assets/configs/bad_primaite_session.yaml
Normal file
@@ -0,0 +1,725 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
se3ed: 333 # Purposeful typo to check that error is raised with bad configuration.
|
||||
n_learn_steps: 2560
|
||||
n_eval_episodes: 5
|
||||
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DROP TABLE IF EXISTS user;"
|
||||
# success_rate: 80%
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
9:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
13:
|
||||
action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
node_id: 6
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 6
|
||||
22:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
28:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 0
|
||||
29:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 1
|
||||
30:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 2
|
||||
31:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 3
|
||||
32:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 4
|
||||
33:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 5
|
||||
34:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 6
|
||||
35:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 7
|
||||
36:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 8
|
||||
37:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 9
|
||||
38:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_ref: database_server
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
nics:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 2
|
||||
729
tests/assets/configs/eval_only_primaite_session.yaml
Normal file
729
tests/assets/configs/eval_only_primaite_session.yaml
Normal file
@@ -0,0 +1,729 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 0
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DROP TABLE IF EXISTS user;"
|
||||
# success_rate: 80%
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
9:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
13:
|
||||
action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
node_id: 6
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 6
|
||||
22:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
28:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 0
|
||||
29:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 1
|
||||
30:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 2
|
||||
31:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 3
|
||||
32:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 4
|
||||
33:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 5
|
||||
34:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 6
|
||||
35:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 7
|
||||
36:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 8
|
||||
37:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 9
|
||||
38:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_ref: database_server
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
nics:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 2
|
||||
733
tests/assets/configs/test_primaite_session.yaml
Normal file
733
tests/assets/configs/test_primaite_session.yaml
Normal file
@@ -0,0 +1,733 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 10
|
||||
n_eval_episodes: 5
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
io_settings:
|
||||
save_checkpoints: true
|
||||
checkpoint_interval: 5
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DROP TABLE IF EXISTS user;"
|
||||
# success_rate: 80%
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
9:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
13:
|
||||
action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
node_id: 6
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 6
|
||||
22:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
28:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 0
|
||||
29:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 1
|
||||
30:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 2
|
||||
31:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 3
|
||||
32:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 4
|
||||
33:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 5
|
||||
34:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 6
|
||||
35:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 7
|
||||
36:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 8
|
||||
37:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 9
|
||||
38:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_ref: database_server
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
nics:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 2
|
||||
729
tests/assets/configs/train_only_primaite_session.yaml
Normal file
729
tests/assets/configs/train_only_primaite_session.yaml
Normal file
@@ -0,0 +1,729 @@
|
||||
training_config:
|
||||
rl_framework: SB3
|
||||
rl_algorithm: PPO
|
||||
seed: 333
|
||||
n_learn_episodes: 10
|
||||
n_eval_episodes: 0
|
||||
max_steps_per_episode: 128
|
||||
deterministic_eval: false
|
||||
n_agents: 1
|
||||
agent_references:
|
||||
- defender
|
||||
|
||||
|
||||
game_config:
|
||||
ports:
|
||||
- ARP
|
||||
- DNS
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
|
||||
agents:
|
||||
- ref: client_1_green_user
|
||||
team: GREEN
|
||||
type: GreenWebBrowsingAgent
|
||||
observation_space:
|
||||
type: UC2GreenObservation
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
# <not yet implemented>
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# target_address: arcd.com
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
max_nics_per_node: 2
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_step: 5
|
||||
frequency: 4
|
||||
variance: 3
|
||||
|
||||
- ref: client_1_data_manipulation_red_bot
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2RedObservation
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
observations:
|
||||
- logon_status
|
||||
- operating_status
|
||||
services:
|
||||
- service_ref: data_manipulation_bot
|
||||
observations:
|
||||
operating_status
|
||||
health_status
|
||||
folders: {}
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
#<not yet implemented
|
||||
# - type: NODE_APPLICATION_EXECUTE
|
||||
# options:
|
||||
# execution_definition:
|
||||
# server_ip: 192.168.1.14
|
||||
# payload: "DROP TABLE IF EXISTS user;"
|
||||
# success_rate: 80%
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_CORRUPT
|
||||
# - type: NODE_FOLDER_DELETE
|
||||
# - type: NODE_FOLDER_CORRUPT
|
||||
- type: NODE_OS_SCAN
|
||||
# - type: NODE_LOGON
|
||||
# - type: NODE_LOGOFF
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: client_1
|
||||
max_folders_per_node: 1
|
||||
max_files_per_folder: 1
|
||||
max_services_per_node: 1
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
|
||||
start_step: 25
|
||||
frequency: 20
|
||||
variance: 5
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
|
||||
observation_space:
|
||||
type: UC2BlueObservation
|
||||
options:
|
||||
num_services_per_node: 1
|
||||
num_folders_per_node: 1
|
||||
num_files_per_folder: 1
|
||||
num_nics_per_node: 2
|
||||
nodes:
|
||||
- node_ref: domain_controller
|
||||
services:
|
||||
- service_ref: domain_controller_dns_server
|
||||
- node_ref: web_server
|
||||
services:
|
||||
- service_ref: web_server_database_client
|
||||
- node_ref: database_server
|
||||
services:
|
||||
- service_ref: database_service
|
||||
folders:
|
||||
- folder_name: database
|
||||
files:
|
||||
- file_name: database.db
|
||||
- node_ref: backup_server
|
||||
# services:
|
||||
# - service_ref: backup_service
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
links:
|
||||
- link_ref: router_1___switch_1
|
||||
- link_ref: router_1___switch_2
|
||||
- link_ref: switch_1___domain_controller
|
||||
- link_ref: switch_1___web_server
|
||||
- link_ref: switch_1___database_server
|
||||
- link_ref: switch_1___backup_server
|
||||
- link_ref: switch_1___security_suite
|
||||
- link_ref: switch_2___client_1
|
||||
- link_ref: switch_2___client_2
|
||||
- link_ref: switch_2___security_suite
|
||||
acl:
|
||||
options:
|
||||
max_acl_rules: 10
|
||||
router_node_ref: router_1
|
||||
ip_address_order:
|
||||
- node_ref: domain_controller
|
||||
nic_num: 1
|
||||
- node_ref: web_server
|
||||
nic_num: 1
|
||||
- node_ref: database_server
|
||||
nic_num: 1
|
||||
- node_ref: backup_server
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 1
|
||||
- node_ref: client_1
|
||||
nic_num: 1
|
||||
- node_ref: client_2
|
||||
nic_num: 1
|
||||
- node_ref: security_suite
|
||||
nic_num: 2
|
||||
ics: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SERVICE_SCAN
|
||||
- type: NODE_SERVICE_STOP
|
||||
- type: NODE_SERVICE_START
|
||||
- type: NODE_SERVICE_PAUSE
|
||||
- type: NODE_SERVICE_RESUME
|
||||
- type: NODE_SERVICE_RESTART
|
||||
- type: NODE_SERVICE_DISABLE
|
||||
- type: NODE_SERVICE_ENABLE
|
||||
- type: NODE_FILE_SCAN
|
||||
- type: NODE_FILE_CHECKHASH
|
||||
- type: NODE_FILE_DELETE
|
||||
- type: NODE_FILE_REPAIR
|
||||
- type: NODE_FILE_RESTORE
|
||||
- type: NODE_FOLDER_SCAN
|
||||
- type: NODE_FOLDER_CHECKHASH
|
||||
- type: NODE_FOLDER_REPAIR
|
||||
- type: NODE_FOLDER_RESTORE
|
||||
- type: NODE_OS_SCAN
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: NODE_RESET
|
||||
- type: NETWORK_ACL_ADDRULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_ACL_REMOVERULE
|
||||
options:
|
||||
target_router_ref: router_1
|
||||
- type: NETWORK_NIC_ENABLE
|
||||
- type: NETWORK_NIC_DISABLE
|
||||
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
# scan webapp service
|
||||
1:
|
||||
action: NODE_SERVICE_SCAN
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# stop webapp service
|
||||
2:
|
||||
action: NODE_SERVICE_STOP
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
# start webapp service
|
||||
3:
|
||||
action: "NODE_SERVICE_START"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
4:
|
||||
action: "NODE_SERVICE_PAUSE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
5:
|
||||
action: "NODE_SERVICE_RESUME"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
6:
|
||||
action: "NODE_SERVICE_RESTART"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
7:
|
||||
action: "NODE_SERVICE_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
8:
|
||||
action: "NODE_SERVICE_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
service_id: 1
|
||||
9:
|
||||
action: "NODE_FILE_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
10:
|
||||
action: "NODE_FILE_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
11:
|
||||
action: "NODE_FILE_DELETE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
12:
|
||||
action: "NODE_FILE_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
13:
|
||||
action: "NODE_FILE_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
file_id: 1
|
||||
14:
|
||||
action: "NODE_FOLDER_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
15:
|
||||
action: "NODE_FOLDER_CHECKHASH"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
16:
|
||||
action: "NODE_FOLDER_REPAIR"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
17:
|
||||
action: "NODE_FOLDER_RESTORE"
|
||||
options:
|
||||
node_id: 3
|
||||
folder_id: 1
|
||||
18:
|
||||
action: "NODE_OS_SCAN"
|
||||
options:
|
||||
node_id: 3
|
||||
19:
|
||||
action: "NODE_SHUTDOWN"
|
||||
options:
|
||||
node_id: 6
|
||||
20:
|
||||
action: "NODE_STARTUP"
|
||||
options:
|
||||
node_id: 6
|
||||
21:
|
||||
action: "NODE_RESET"
|
||||
options:
|
||||
node_id: 6
|
||||
22:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
23:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 1
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 1
|
||||
24:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
25:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 3
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
26:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 7
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
27:
|
||||
action: "NETWORK_ACL_ADDRULE"
|
||||
options:
|
||||
position: 1
|
||||
permission: 2
|
||||
source_ip_id: 8
|
||||
dest_ip_id: 4
|
||||
source_port_id: 1
|
||||
dest_port_id: 1
|
||||
protocol_id: 3
|
||||
28:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 0
|
||||
29:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 1
|
||||
30:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 2
|
||||
31:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 3
|
||||
32:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 4
|
||||
33:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 5
|
||||
34:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 6
|
||||
35:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 7
|
||||
36:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 8
|
||||
37:
|
||||
action: "NETWORK_ACL_REMOVERULE"
|
||||
options:
|
||||
position: 9
|
||||
38:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
39:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 1
|
||||
40:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
41:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 2
|
||||
nic_id: 1
|
||||
42:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
43:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 3
|
||||
nic_id: 1
|
||||
44:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
45:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 4
|
||||
nic_id: 1
|
||||
46:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
47:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 1
|
||||
48:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
49:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 5
|
||||
nic_id: 2
|
||||
50:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
51:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 6
|
||||
nic_id: 1
|
||||
52:
|
||||
action: "NETWORK_NIC_DISABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
53:
|
||||
action: "NETWORK_NIC_ENABLE"
|
||||
options:
|
||||
node_id: 7
|
||||
nic_id: 1
|
||||
|
||||
|
||||
options:
|
||||
nodes:
|
||||
- node_ref: router_1
|
||||
- node_ref: switch_1
|
||||
- node_ref: switch_2
|
||||
- node_ref: domain_controller
|
||||
- node_ref: web_server
|
||||
- node_ref: database_server
|
||||
- node_ref: backup_server
|
||||
- node_ref: security_suite
|
||||
- node_ref: client_1
|
||||
- node_ref: client_2
|
||||
max_folders_per_node: 2
|
||||
max_files_per_folder: 2
|
||||
max_services_per_node: 2
|
||||
max_nics_per_node: 8
|
||||
max_acl_rules: 10
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
|
||||
- type: WEB_SERVER_404_PENALTY
|
||||
weight: 0.5
|
||||
options:
|
||||
node_ref: web_server
|
||||
service_ref: web_server_web_service
|
||||
|
||||
|
||||
agent_settings:
|
||||
# ...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
|
||||
- ref: router_1
|
||||
type: router
|
||||
hostname: router_1
|
||||
num_ports: 5
|
||||
ports:
|
||||
1:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
2:
|
||||
ip_address: 192.168.1.1
|
||||
subnet_mask: 255.255.255.0
|
||||
acl:
|
||||
0:
|
||||
action: PERMIT
|
||||
src_port: POSTGRES_SERVER
|
||||
dst_port: POSTGRES_SERVER
|
||||
1:
|
||||
action: PERMIT
|
||||
src_port: DNS
|
||||
dst_port: DNS
|
||||
22:
|
||||
action: PERMIT
|
||||
src_port: ARP
|
||||
dst_port: ARP
|
||||
23:
|
||||
action: PERMIT
|
||||
protocol: ICMP
|
||||
|
||||
- ref: switch_1
|
||||
type: switch
|
||||
hostname: switch_1
|
||||
num_ports: 8
|
||||
|
||||
- ref: switch_2
|
||||
type: switch
|
||||
hostname: switch_2
|
||||
num_ports: 8
|
||||
|
||||
- ref: domain_controller
|
||||
type: server
|
||||
hostname: domain_controller
|
||||
ip_address: 192.168.1.10
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- ref: domain_controller_dns_server
|
||||
type: DNSServer
|
||||
options:
|
||||
domain_mapping:
|
||||
arcd.com: 192.168.1.12 # web server
|
||||
|
||||
- ref: web_server
|
||||
type: server
|
||||
hostname: web_server
|
||||
ip_address: 192.168.1.12
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.10
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: web_server_database_client
|
||||
type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.14
|
||||
- ref: web_server_web_service
|
||||
type: WebServer
|
||||
|
||||
|
||||
- ref: database_server
|
||||
type: server
|
||||
hostname: database_server
|
||||
ip_address: 192.168.1.14
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: database_service
|
||||
type: DatabaseService
|
||||
|
||||
- ref: backup_server
|
||||
type: server
|
||||
hostname: backup_server
|
||||
ip_address: 192.168.1.16
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: backup_service
|
||||
type: DatabaseBackup
|
||||
|
||||
- ref: security_suite
|
||||
type: server
|
||||
hostname: security_suite
|
||||
ip_address: 192.168.1.110
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
dns_server: 192.168.1.10
|
||||
nics:
|
||||
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
|
||||
ip_address: 192.168.10.110
|
||||
subnet_mask: 255.255.255.0
|
||||
|
||||
- ref: client_1
|
||||
type: computer
|
||||
hostname: client_1
|
||||
ip_address: 192.168.10.21
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
services:
|
||||
- ref: data_manipulation_bot
|
||||
type: DataManipulationBot
|
||||
- ref: client_1_dns_client
|
||||
type: DNSClient
|
||||
|
||||
- ref: client_2
|
||||
type: computer
|
||||
hostname: client_2
|
||||
ip_address: 192.168.10.22
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.10.1
|
||||
dns_server: 192.168.1.10
|
||||
applications:
|
||||
- ref: client_2_web_browser
|
||||
type: WebBrowser
|
||||
services:
|
||||
- ref: client_2_dns_client
|
||||
type: DNSClient
|
||||
|
||||
links:
|
||||
- ref: router_1___switch_1
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: switch_1
|
||||
endpoint_b_port: 8
|
||||
- ref: router_1___switch_2
|
||||
endpoint_a_ref: router_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: switch_2
|
||||
endpoint_b_port: 8
|
||||
- ref: switch_1___domain_controller
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: domain_controller
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___web_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: web_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___database_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 3
|
||||
endpoint_b_ref: database_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___backup_server
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 4
|
||||
endpoint_b_ref: backup_server
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_1___security_suite
|
||||
endpoint_a_ref: switch_1
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_1
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_ref: client_1
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___client_2
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 2
|
||||
endpoint_b_ref: client_2
|
||||
endpoint_b_port: 1
|
||||
- ref: switch_2___security_suite
|
||||
endpoint_a_ref: switch_2
|
||||
endpoint_a_port: 7
|
||||
endpoint_b_ref: security_suite
|
||||
endpoint_b_port: 2
|
||||
@@ -5,12 +5,13 @@ import tempfile
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import nodeenv
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.game.session import PrimaiteSession
|
||||
|
||||
# from primaite.environment.primaite_env import Primaite
|
||||
# from primaite.primaite_session import PrimaiteSession
|
||||
@@ -20,13 +21,15 @@ from primaite.simulator.network.transmission.transport_layer import Port
|
||||
from primaite.simulator.system.applications.application import Application
|
||||
from primaite.simulator.system.core.sys_log import SysLog
|
||||
from primaite.simulator.system.services.service import Service
|
||||
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
|
||||
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
|
||||
|
||||
ACTION_SPACE_NODE_VALUES = 1
|
||||
ACTION_SPACE_NODE_ACTION_VALUES = 1
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
from primaite import PRIMAITE_PATHS
|
||||
|
||||
# PrimAITE v3 stuff
|
||||
from primaite.simulator.file_system.file_system import FileSystem
|
||||
from primaite.simulator.network.hardware.base import Node
|
||||
@@ -71,102 +74,32 @@ def file_system() -> FileSystem:
|
||||
|
||||
|
||||
# PrimAITE v2 stuff
|
||||
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
class TempPrimaiteSession: # PrimaiteSession):
|
||||
class TempPrimaiteSession(PrimaiteSession):
|
||||
"""
|
||||
A temporary PrimaiteSession class.
|
||||
|
||||
Uses context manager for deletion of files upon exit.
|
||||
"""
|
||||
|
||||
# def __init__(
|
||||
# self,
|
||||
# training_config_path: Union[str, Path],
|
||||
# lay_down_config_path: Union[str, Path],
|
||||
# ):
|
||||
# super().__init__(training_config_path, lay_down_config_path)
|
||||
# self.setup()
|
||||
@classmethod
|
||||
def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession":
|
||||
"""Create a temporary PrimaiteSession object from a config file."""
|
||||
config_path = Path(config_path)
|
||||
with open(config_path, "r") as f:
|
||||
config = yaml.safe_load(f)
|
||||
|
||||
# @property
|
||||
# def env(self) -> Primaite:
|
||||
# """Direct access to the env for ease of testing."""
|
||||
# return self._agent_session._env # noqa
|
||||
return super().from_config(cfg=config)
|
||||
|
||||
# def __enter__(self):
|
||||
# return self
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
# def __exit__(self, type, value, tb):
|
||||
# shutil.rmtree(self.session_path)
|
||||
# _LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
|
||||
def __exit__(self, type, value, tb):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.fixture
|
||||
def temp_primaite_session(request):
|
||||
"""
|
||||
Provides a temporary PrimaiteSession instance.
|
||||
|
||||
It's temporary as it uses a temporary directory as the session path.
|
||||
|
||||
To use this fixture you need to:
|
||||
|
||||
- parametrize your test function with:
|
||||
|
||||
- "temp_primaite_session"
|
||||
- [[path to training config, path to lay down config]]
|
||||
- Include the temp_primaite_session fixture as a param in your test
|
||||
function.
|
||||
- use the temp_primaite_session as a context manager assigning is the
|
||||
name 'session'.
|
||||
|
||||
.. code:: python
|
||||
|
||||
from primaite.config.lay_down_config import dos_very_basic_config_path
|
||||
from primaite.config.training_config import main_training_config_path
|
||||
@pytest.mark.parametrize(
|
||||
"temp_primaite_session",
|
||||
[
|
||||
[main_training_config_path(), dos_very_basic_config_path()]
|
||||
],
|
||||
indirect=True
|
||||
)
|
||||
def test_primaite_session(temp_primaite_session):
|
||||
with temp_primaite_session as session:
|
||||
# Learning outputs are saved in session.learning_path
|
||||
session.learn()
|
||||
|
||||
# Evaluation outputs are saved in session.evaluation_path
|
||||
session.evaluate()
|
||||
|
||||
# To ensure that all files are written, you must call .close()
|
||||
session.close()
|
||||
|
||||
# If you need to inspect any session outputs, it must be done
|
||||
# inside the context manager
|
||||
|
||||
# Now that we've exited the context manager, the
|
||||
# session.session_path directory and its contents are deleted
|
||||
"""
|
||||
training_config_path = request.param[0]
|
||||
lay_down_config_path = request.param[1]
|
||||
with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck:
|
||||
mck.session_timestamp = datetime.now()
|
||||
|
||||
return TempPrimaiteSession(training_config_path, lay_down_config_path)
|
||||
|
||||
|
||||
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
|
||||
@pytest.fixture
|
||||
def temp_session_path() -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
:return: The session directory path.
|
||||
"""
|
||||
session_timestamp = datetime.now()
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path(tempfile.gettempdir()) / "_primaite" / date_dir / session_path
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
|
||||
"""Create a temporary PrimaiteSession object."""
|
||||
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
|
||||
config_path = request.param[0]
|
||||
return TempPrimaiteSession.from_config(config_path=config_path)
|
||||
|
||||
68
tests/e2e_integration_tests/test_primaite_session.py
Normal file
68
tests/e2e_integration_tests/test_primaite_session.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import pydantic
|
||||
import pytest
|
||||
|
||||
from tests.conftest import TempPrimaiteSession
|
||||
|
||||
CFG_PATH = "tests/assets/configs/test_primaite_session.yaml"
|
||||
TRAINING_ONLY_PATH = "tests/assets/configs/train_only_primaite_session.yaml"
|
||||
EVAL_ONLY_PATH = "tests/assets/configs/eval_only_primaite_session.yaml"
|
||||
MISCONFIGURED_PATH = "tests/assets/configs/bad_primaite_session.yaml"
|
||||
|
||||
|
||||
class TestPrimaiteSession:
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_creating_session(self, temp_primaite_session):
|
||||
"""Check that creating a session from config works."""
|
||||
with temp_primaite_session as session:
|
||||
if not isinstance(session, TempPrimaiteSession):
|
||||
raise AssertionError
|
||||
|
||||
assert session is not None
|
||||
assert session.simulation
|
||||
assert len(session.agents) == 3
|
||||
assert len(session.rl_agents) == 1
|
||||
|
||||
assert session.policy
|
||||
assert session.env
|
||||
|
||||
assert session.simulation.network
|
||||
assert len(session.simulation.network.nodes) == 10
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
|
||||
def test_start_session(self, temp_primaite_session):
|
||||
"""Make sure you can go all the way through the session without errors."""
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
|
||||
session_path = session.io_manager.session_path
|
||||
assert session_path.exists()
|
||||
print(list(session_path.glob("*")))
|
||||
checkpoint_dir = session_path / "checkpoints" / "sb3_final"
|
||||
assert checkpoint_dir.exists()
|
||||
checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
|
||||
checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
|
||||
checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
|
||||
assert checkpoint_1.exists()
|
||||
assert checkpoint_2.exists()
|
||||
assert not checkpoint_3.exists()
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
|
||||
def test_training_only_session(self, temp_primaite_session):
|
||||
"""Check that you can run a training-only session."""
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
|
||||
|
||||
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
|
||||
def test_eval_only_session(self, temp_primaite_session):
|
||||
"""Check that you can load a model and run an eval-only session."""
|
||||
with temp_primaite_session as session:
|
||||
session: TempPrimaiteSession
|
||||
session.start_session()
|
||||
# TODO: include checks that the model was loaded and that the eval-only session ran
|
||||
|
||||
def test_error_thrown_on_bad_configuration(self):
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
|
||||
@@ -9,7 +9,7 @@ from primaite import getLogger
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
def temp_user_sessions_path() -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user