Update data manipulation bot

This commit is contained in:
Jake Walker
2023-11-24 10:33:19 +00:00
41 changed files with 3648 additions and 499 deletions

View File

@@ -29,17 +29,6 @@ jobs:
pip install -e .[dev]
displayName: 'Install PrimAITE for docs autosummary'
- script: |
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
python -m pip install $GATE_WHEEL[dev]
displayName: 'Install GATE'
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
- script: |
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
displayName: 'Install GATE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
primaite setup
displayName: 'Perform PrimAITE Setup'

View File

@@ -81,17 +81,6 @@ stages:
displayName: 'Install PrimAITE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
GATE_WHEEL=$(ls ./GATE/arcd_gate*.whl)
python -m pip install $GATE_WHEEL[dev]
displayName: 'Install GATE'
condition: or(eq( variables['Agent.OS'], 'Linux' ), eq( variables['Agent.OS'], 'Darwin' ))
- script: |
forfiles /p GATE\ /m *.whl /c "cmd /c python -m pip install @file[dev]"
displayName: 'Install GATE'
condition: eq( variables['Agent.OS'], 'Windows_NT' )
- script: |
primaite setup
displayName: 'Perform PrimAITE Setup'

View File

@@ -46,7 +46,6 @@ python3 -m venv .venv
attrib +h .venv /s /d # Hides the .venv directory
.\.venv\Scripts\activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install GATE/arcd_gate-0.1.0-py3-none-any.whl
primaite setup
```
@@ -75,7 +74,6 @@ cd ~/primaite
python3 -m venv .venv
source .venv/bin/activate
pip install https://github.com/Autonomous-Resilient-Cyber-Defence/PrimAITE/releases/download/v2.0.0/primaite-2.0.0-py3-none-any.whl
pip install arcd_gate-0.1.0-py3-none-any.whl
primaite setup
```
@@ -120,7 +118,6 @@ source venv/bin/activate
```bash
python3 -m pip install -e .[dev]
pip install arcd_gate-0.1.0-py3-none-any.whl
```
#### 6. Perform the PrimAITE setup:

View File

@@ -53,7 +53,7 @@ Example
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
data_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
data_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
data_manipulation_bot.run()
This would connect to the database service at 192.168.1.14, authenticate, and execute the DELETE query against the database.

View File

@@ -14,10 +14,10 @@ The ``DatabaseService`` provides a SQL database server simulation by extending t
Key capabilities
^^^^^^^^^^^^^^^^
- Initialises a SQLite database file in the ``Node`` 's ``FileSystem`` upon creation.
- Creates a database file in the ``Node`` 's ``FileSystem`` upon creation.
- Handles connecting clients by maintaining a dictionary of connections mapped to session IDs.
- Authenticates connections using a configurable password.
- Executes SQL queries against the SQLite database.
- Simulates ``SELECT`` and ``DELETE`` SQL queries.
- Returns query results and status codes back to clients.
- Leverages the Service base class for install/uninstall, status tracking, etc.
@@ -30,10 +30,9 @@ Usage
Implementation
^^^^^^^^^^^^^^
- Uses SQLite for persistent storage.
- Creates the database file within the node's file system.
- Manages client connections in a dictionary by session ID.
- Processes SQL queries via the SQLite cursor and connection.
- Processes SQL queries.
- Returns results and status codes in a standard dictionary format.
- Extends Service class for integration with ``SoftwareManager``.

View File

@@ -38,7 +38,7 @@ dependencies = [
"stable-baselines3[extra]==2.1.0",
"tensorflow==2.12.0",
"typer[all]==0.9.0",
"pydantic==2.1.1"
"pydantic==2.1.1",
]
[tool.setuptools.dynamic]

View File

@@ -29,6 +29,15 @@ class _PrimaitePaths:
def __init__(self) -> None:
self._dirs: Final[PlatformDirs] = PlatformDirs(appname="primaite", version=__version__)
self.user_home_path = self.generate_user_home_path()
self.user_sessions_path = self.generate_user_sessions_path()
self.user_config_path = self.generate_user_config_path()
self.user_notebooks_path = self.generate_user_notebooks_path()
self.app_home_path = self.generate_app_home_path()
self.app_config_dir_path = self.generate_app_config_dir_path()
self.app_config_file_path = self.generate_app_config_file_path()
self.app_log_dir_path = self.generate_app_log_dir_path()
self.app_log_file_path = self.generate_app_log_file_path()
def _get_dirs_properties(self) -> List[str]:
class_items = self.__class__.__dict__.items()
@@ -43,55 +52,47 @@ class _PrimaitePaths:
for p in self._get_dirs_properties():
getattr(self, p)
@property
def user_home_path(self) -> Path:
def generate_user_home_path(self) -> Path:
"""The PrimAITE user home path."""
path = Path.home() / "primaite" / __version__
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_sessions_path(self) -> Path:
def generate_user_sessions_path(self) -> Path:
"""The PrimAITE user sessions path."""
path = self.user_home_path / "sessions"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_config_path(self) -> Path:
def generate_user_config_path(self) -> Path:
"""The PrimAITE user config path."""
path = self.user_home_path / "config"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def user_notebooks_path(self) -> Path:
def generate_user_notebooks_path(self) -> Path:
"""The PrimAITE user notebooks path."""
path = self.user_home_path / "notebooks"
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_home_path(self) -> Path:
def generate_app_home_path(self) -> Path:
"""The PrimAITE app home path."""
path = self._dirs.user_data_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_dir_path(self) -> Path:
def generate_app_config_dir_path(self) -> Path:
"""The PrimAITE app config directory path."""
path = self._dirs.user_config_path
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_config_file_path(self) -> Path:
def generate_app_config_file_path(self) -> Path:
"""The PrimAITE app config file path."""
return self.app_config_dir_path / "primaite_config.yaml"
@property
def app_log_dir_path(self) -> Path:
def generate_app_log_dir_path(self) -> Path:
"""The PrimAITE app log directory path."""
if sys.platform == "win32":
path = self.app_home_path / "logs"
@@ -100,8 +101,7 @@ class _PrimaitePaths:
path.mkdir(exist_ok=True, parents=True)
return path
@property
def app_log_file_path(self) -> Path:
def generate_app_log_file_path(self) -> Path:
"""The PrimAITE app log file path."""
return self.app_log_dir_path / "primaite.log"
@@ -133,6 +133,7 @@ def _get_primaite_config() -> Dict:
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
"WARN": logging.WARN,
"WARNING": logging.WARN,
"ERROR": logging.ERROR,
"CRITICAL": logging.CRITICAL,
}

View File

@@ -95,8 +95,6 @@ def setup(overwrite_existing: bool = True) -> None:
WARNING: All user-data will be lost.
"""
from arcd_gate.cli import setup as gate_setup
from primaite import getLogger
from primaite.setup import reset_demo_notebooks, reset_example_configs
@@ -115,15 +113,13 @@ def setup(overwrite_existing: bool = True) -> None:
_LOGGER.info("Rebuilding the example notebooks...")
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("Setting up ARCD GATE...")
gate_setup()
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
config: Optional[str] = None,
agent_load_file: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
@@ -131,16 +127,10 @@ def session(
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from threading import Thread
from primaite.config.load import example_config_path
from primaite.main import run
from primaite.utils.start_gate_server import start_gate_server
server_thread = Thread(target=start_gate_server)
server_thread.start()
if not config:
config = example_config_path()
print(config)
run(config_path=config)
run(config_path=config, agent_load_path=agent_load_file)

View File

@@ -2,10 +2,17 @@ training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 20
n_learn_steps: 128
n_eval_episodes: 20
n_eval_steps: 128
n_learn_episodes: 25
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game_config:
@@ -107,7 +114,7 @@ game_config:
- ref: defender
team: BLUE
type: GATERLAgent
type: ProxyAgent
observation_space:
type: UC2BlueObservation

View File

@@ -1,31 +0,0 @@
# flake8: noqa
from typing import Dict, Optional, Tuple
from gymnasium.core import ActType, ObsType
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractGATEAgent, ObsType
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.rewards import RewardFunction
class GATERLAgent(AbstractGATEAgent):
...
# The communication with GATE needs to be handled by the PrimaiteSession, rather than by individual agents,
# because when we are supporting MARL, the actions form multiple agents will have to be batched
# For example MultiAgentEnv in Ray allows sending a dict of observations of multiple agents, then it will reply
# with the actions for those agents.
def __init__(
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
reward_function: Optional[RewardFunction],
) -> None:
super().__init__(agent_name, action_space, observation_space, reward_function)
self.most_recent_action: ActType
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
return self.most_recent_action

View File

@@ -4,10 +4,11 @@ from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, TypeAlias, Union
import numpy as np
from gymnasium.core import ActType, ObsType
from pydantic import BaseModel
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
if TYPE_CHECKING:
@@ -55,7 +56,7 @@ class AbstractAgent(ABC):
self,
agent_name: Optional[str],
action_space: Optional[ActionManager],
observation_space: Optional[ObservationSpace],
observation_space: Optional[ObservationManager],
reward_function: Optional[RewardFunction],
agent_settings: Optional[AgentSettings],
) -> None:
@@ -72,21 +73,21 @@ class AbstractAgent(ABC):
:type reward_function: Optional[RewardFunction]
"""
self.agent_name: str = agent_name or "unnamed_agent"
self.action_space: Optional[ActionManager] = action_space
self.observation_space: Optional[ObservationSpace] = observation_space
self.action_manager: Optional[ActionManager] = action_space
self.observation_manager: Optional[ObservationManager] = observation_space
self.reward_function: Optional[RewardFunction] = reward_function
self.agent_settings = agent_settings or AgentSettings()
def convert_state_to_obs(self, state: Dict) -> ObsType:
def update_observation(self, state: Dict) -> ObsType:
"""
Convert a state from the simulator into an observation for the agent using the observation space.
state : dict state directly from simulation.describe_state
output : dict state according to CAOS.
"""
return self.observation_space.observe(state)
return self.observation_manager.update(state)
def calculate_reward_from_state(self, state: Dict) -> float:
def update_reward(self, state: Dict) -> float:
"""
Use the reward function to calculate a reward from the state.
@@ -95,10 +96,10 @@ class AbstractAgent(ABC):
:return: Reward from the state.
:rtype: float
"""
return self.reward_function.calculate(state)
return self.reward_function.update(state)
@abstractmethod
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""
Return an action to be taken in the environment.
@@ -111,7 +112,7 @@ class AbstractAgent(ABC):
:return: Action to be taken in the environment.
:rtype: Tuple[str, Dict]
"""
# in RL agent, this method will send CAOS observation to GATE RL agent, then receive a int 0-39,
# in RL agent, this method will send CAOS observation to RL agent, then receive a int 0-39,
# then use a bespoke conversion to take 1-40 int back into CAOS action
return ("DO_NOTHING", {})
@@ -119,7 +120,7 @@ class AbstractAgent(ABC):
# this will take something like APPLICATION.EXECUTE and add things like target_ip_address in simulator.
# therefore the execution definition needs to be a mapping from CAOS into SIMULATOR
"""Format action into format expected by the simulator, and apply execution definition if applicable."""
request = self.action_space.form_request(action_identifier=action, action_options=options)
request = self.action_manager.form_request(action_identifier=action, action_options=options)
return request
@@ -132,7 +133,7 @@ class AbstractScriptedAgent(AbstractAgent):
class RandomAgent(AbstractScriptedAgent):
"""Agent that ignores its observation and acts completely at random."""
def get_action(self, obs: ObsType, reward: float = None) -> Tuple[str, Dict]:
def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
"""Randomly sample an action from the action space.
:param obs: _description_
@@ -142,7 +143,47 @@ class RandomAgent(AbstractScriptedAgent):
:return: _description_
:rtype: Tuple[str, Dict]
"""
return self.action_space.get_action(self.action_space.space.sample())
return self.action_manager.get_action(self.action_manager.space.sample())
class ProxyAgent(AbstractAgent):
    """Agent that sends observations to an RL model and receives actions from that model."""

    def __init__(
        self,
        agent_name: Optional[str],
        action_space: Optional[ActionManager],
        observation_space: Optional[ObservationManager],
        reward_function: Optional[RewardFunction],
    ) -> None:
        """
        Initialise a ProxyAgent.

        :param agent_name: Name of the agent.
        :type agent_name: Optional[str]
        :param action_space: Action manager used to convert stored RL actions into CAOS format.
        :type action_space: Optional[ActionManager]
        :param observation_space: Observation manager for the agent.
        :type observation_space: Optional[ObservationManager]
        :param reward_function: Reward function for the agent.
        :type reward_function: Optional[RewardFunction]
        """
        # NOTE(review): AbstractAgent.__init__ also accepts an `agent_settings` argument which is not passed here -
        # confirm it has a default, otherwise this call raises TypeError.
        super().__init__(
            agent_name=agent_name,
            action_space=action_space,
            observation_space=observation_space,
            reward_function=reward_function,
        )
        # Annotation only - no value is assigned here. Calling get_action() before store_action() would raise
        # AttributeError; the environment is expected to call store_action() first.
        self.most_recent_action: ActType

    def get_action(self, obs: ObsType, reward: float = 0.0) -> Tuple[str, Dict]:
        """
        Return the agent's most recent action, formatted in CAOS format.

        :param obs: Observation for the agent. Not used by ProxyAgents, but required by the interface.
        :type obs: ObsType
        :param reward: Reward value for the agent. Not used by ProxyAgents, defaults to 0.0.
        :type reward: float, optional
        :return: Action to be taken in CAOS format.
        :rtype: Tuple[str, Dict]
        """
        return self.action_manager.get_action(self.most_recent_action)

    def store_action(self, action: ActType):
        """
        Store the most recent action taken by the agent.

        The environment is responsible for calling this method when it receives an action from the agent policy.

        :param action: The raw action chosen by the RL policy, later converted by get_action().
        """
        self.most_recent_action = action
class DataManipulationAgent(AbstractScriptedAgent):

View File

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from gymnasium import spaces
from gymnasium.core import ObsType
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -926,7 +927,7 @@ class UC2GreenObservation(NullObservation):
pass
class ObservationSpace:
class ObservationManager:
"""
Manage the observations of an Agent.
@@ -947,15 +948,17 @@ class ObservationSpace:
:type observation: AbstractObservation
"""
self.obs: AbstractObservation = observation
self.current_observation: ObsType
def observe(self, state: Dict) -> Dict:
def update(self, state: Dict) -> Dict:
"""
Generate observation based on the current state of the simulation.
:param state: Simulation state dictionary
:type state: Dict
"""
return self.obs.observe(state)
self.current_observation = self.obs.observe(state)
return self.current_observation
@property
def space(self) -> None:
@@ -963,7 +966,7 @@ class ObservationSpace:
return self.obs.space
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationSpace":
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "ObservationManager":
"""Create observation space from a config.
:param config: Dictionary containing the configuration for this observation space.

View File

@@ -26,7 +26,7 @@ the structure:
```
"""
from abc import abstractmethod
from typing import Dict, List, Tuple, TYPE_CHECKING
from typing import Dict, List, Tuple, Type, TYPE_CHECKING
from primaite import getLogger
from primaite.game.agent.utils import access_from_nested_dict, NOT_PRESENT_IN_STATE
@@ -228,7 +228,7 @@ class WebServer404Penalty(AbstractReward):
class RewardFunction:
"""Manages the reward function for the agent."""
__rew_class_identifiers: Dict[str, type[AbstractReward]] = {
__rew_class_identifiers: Dict[str, Type[AbstractReward]] = {
"DUMMY": DummyReward,
"DATABASE_FILE_INTEGRITY": DatabaseFileIntegrity,
"WEB_SERVER_404_PENALTY": WebServer404Penalty,
@@ -238,6 +238,7 @@ class RewardFunction:
"""Initialise the reward function object."""
self.reward_components: List[Tuple[AbstractReward, float]] = []
"attribute reward_components keeps track of reward components and the weights assigned to each."
self.current_reward: float
def regsiter_component(self, component: AbstractReward, weight: float = 1.0) -> None:
"""Add a reward component to the reward function.
@@ -249,7 +250,7 @@ class RewardFunction:
"""
self.reward_components.append((component, weight))
def calculate(self, state: Dict) -> float:
def update(self, state: Dict) -> float:
"""Calculate the overall reward for the current state.
:param state: The current state of the simulation.
@@ -260,7 +261,8 @@ class RewardFunction:
comp = comp_and_weight[0]
weight = comp_and_weight[1]
total += weight * comp.calculate(state=state)
return total
self.current_reward = total
return self.current_reward
@classmethod
def from_config(cls, config: Dict, session: "PrimaiteSession") -> "RewardFunction":

63
src/primaite/game/io.py Normal file
View File

@@ -0,0 +1,63 @@
from datetime import datetime
from pathlib import Path
from typing import Optional
from pydantic import BaseModel, ConfigDict
from primaite import PRIMAITE_PATHS
from primaite.simulator import SIM_OUTPUT
class SessionIOSettings(BaseModel):
    """Schema for session IO settings."""

    model_config = ConfigDict(extra="forbid")

    save_final_model: bool = True
    """Whether to save the final model right at the end of training."""
    save_checkpoints: bool = False
    """Whether to save a checkpoint model every `checkpoint_interval` episodes."""
    checkpoint_interval: int = 10
    """How often to save a checkpoint model (if save_checkpoints is True)."""
    save_logs: bool = True
    """Whether to save logs."""
    save_transactions: bool = True
    """Whether to save transactions. If true, the session path will have a transactions folder."""
    save_tensorboard_logs: bool = False
    """Whether to save tensorboard logs. If true, the session path will have a tensorboard_logs folder."""
class SessionIO:
    """
    Class for managing session IO.

    Currently it handles path generation, but could expand to handle loading, transactions, tensorboard, and so on.
    """

    def __init__(self, settings: Optional["SessionIOSettings"] = None) -> None:
        """
        Initialise the session IO manager and create the session output folder.

        :param settings: IO settings for this session. When None, a fresh ``SessionIOSettings`` with default values
            is created. (A None default is used instead of a shared ``SessionIOSettings()`` default argument so that
            separate ``SessionIO`` instances never share one mutable settings object.)
        :type settings: Optional[SessionIOSettings]
        """
        self.settings: "SessionIOSettings" = settings if settings is not None else SessionIOSettings()
        self.session_path: Path = self.generate_session_path()
        # set global SIM_OUTPUT path so the simulator writes its output under this session's folder
        SIM_OUTPUT.path = self.session_path / "simulation_output"
        # WARNING TODO: be careful not to re-initialise SessionIO, because a new session path is created each time
        # it is constructed - a refactor may be needed.

    def generate_session_path(self, timestamp: Optional[datetime] = None) -> Path:
        """
        Create a timestamped folder for the session and return the path to it.

        :param timestamp: The timestamp to build the path from. Defaults to ``datetime.now()`` when None.
        :type timestamp: Optional[datetime]
        :return: Path of the form ``<user_sessions_path>/<YYYY-MM-DD>/<HH-MM-SS>``, created if necessary.
        :rtype: Path
        """
        if timestamp is None:
            timestamp = datetime.now()
        date_str = timestamp.strftime("%Y-%m-%d")
        time_str = timestamp.strftime("%H-%M-%S")
        session_path = PRIMAITE_PATHS.user_sessions_path / date_str / time_str
        session_path.mkdir(exist_ok=True, parents=True)
        return session_path

    def generate_model_save_path(self, agent_name: str) -> Path:
        """Return the path where the final model will be saved (excluding filename extension)."""
        return self.session_path / "checkpoints" / f"{agent_name}_final"

    def generate_checkpoint_save_path(self, agent_name: str, episode: int) -> Path:
        """Return the path where the checkpoint model will be saved (including the ``.pt`` extension)."""
        return self.session_path / "checkpoints" / f"{agent_name}_checkpoint_{episode}.pt"

View File

@@ -0,0 +1,3 @@
from primaite.game.policy.sb3 import SB3Policy
__all__ = ["SB3Policy"]

View File

@@ -0,0 +1,84 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
    """Base class for reinforcement learning agents."""

    _registry: Dict[str, Type["PolicyABC"]] = {}
    """
    Registry of policy types, keyed by name.

    Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
    """

    def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
        """
        Register a policy subclass.

        :param identifier: Identifier used by from_config to create an instance of the policy.
        :type identifier: str
        :raises ValueError: When attempting to register a policy with a duplicate identifier.
        """
        super().__init_subclass__(**kwargs)
        if identifier in cls._registry:
            raise ValueError(f"Duplicate policy name {identifier}")
        cls._registry[identifier] = cls
        return

    @abstractmethod
    def __init__(self, session: "PrimaiteSession") -> None:
        """
        Initialize a reinforcement learning policy.

        :param session: The session context.
        :type session: PrimaiteSession
        """
        self.session: "PrimaiteSession" = session
        """Reference to the session."""

    @abstractmethod
    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """Train the agent."""
        pass

    @abstractmethod
    def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
        """Evaluate the agent."""
        pass

    @abstractmethod
    def save(self, save_path: Path) -> None:
        """Save the agent."""
        pass

    @abstractmethod
    def load(self) -> None:
        """Load agent from a file."""
        pass

    def close(self) -> None:
        """Close the agent. No-op by default; subclasses override when cleanup is needed."""
        pass

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
        """
        Create an RL policy from a config by calling the relevant subclass's from_config method.

        Subclasses should not call super().from_config(), they should just handle creation from config.

        :param config: Training options whose ``rl_framework`` selects the registered policy subclass.
        :type config: TrainingOptions
        :param session: The session context passed through to the subclass.
        :type session: PrimaiteSession
        :raises KeyError: When ``config.rl_framework`` does not match any registered policy identifier.
        """
        # Assume that basically the contents of training_config are passed into here.
        # I should really define a config schema class using pydantic.
        PolicyType = cls._registry[config.rl_framework]
        return PolicyType.from_config(config=config, session=session)
    # saving checkpoints logic will be handled here, it will invoke 'save' method which is implemented by the subclass

View File

@@ -0,0 +1,80 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.game.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.game.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
    """Single agent RL policy using stable baselines 3."""

    def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
        """
        Initialize a stable baselines 3 policy.

        :param session: The PrimAITE session that owns this policy; provides the gym env via ``session.env``.
        :type session: PrimaiteSession
        :param algorithm: Which SB3 algorithm to use; only "PPO" and "A2C" are supported.
        :type algorithm: Literal["PPO", "A2C"]
        :param seed: Random seed passed to the SB3 agent. Defaults to None.
        :type seed: Optional[int]
        :raises ValueError: If ``algorithm`` is not "PPO" or "A2C".
        """
        super().__init__(session=session)
        self._agent_class: Type[Union[PPO, A2C]]
        if algorithm == "PPO":
            self._agent_class = PPO
            policy = PPO_MLP
        elif algorithm == "A2C":
            self._agent_class = A2C
            policy = A2C_MLP
        else:
            raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
        self._agent = self._agent_class(
            policy=policy,
            env=self.session.env,
            n_steps=128,  # this is not the number of steps in an episode, but the number of steps in a batch
            seed=seed,
        )

    def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
        """
        Train the agent for ``n_episodes * timesteps_per_episode`` total timesteps.

        When checkpoint saving is enabled in the session IO settings, an SB3 ``CheckpointCallback`` saves the model
        every ``timesteps_per_episode * checkpoint_interval`` steps.
        """
        if self.session.io_manager.settings.save_checkpoints:
            # NOTE(review): generate_model_save_path returns a path ending in `sb3_final`; CheckpointCallback treats
            # save_path as a directory, so checkpoints land in a folder named `sb3_final` - confirm this is intended.
            checkpoint_callback = CheckpointCallback(
                save_freq=timesteps_per_episode * self.session.io_manager.settings.checkpoint_interval,
                save_path=self.session.io_manager.generate_model_save_path("sb3"),
                name_prefix="sb3_model",
            )
        else:
            checkpoint_callback = None
        self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)

    def eval(self, n_episodes: int, deterministic: bool) -> None:
        """
        Evaluate the agent and print the per-episode reward data.

        NOTE(review): this signature omits the ``timesteps_per_episode`` parameter declared by ``PolicyABC.eval``;
        callers using the abstract interface positionally will misbind arguments - confirm and align the signatures.
        """
        reward_data = evaluate_policy(
            self._agent,
            self.session.env,
            n_eval_episodes=n_episodes,
            deterministic=deterministic,
            return_episode_rewards=True,
        )
        print(reward_data)

    def save(self, save_path: Path) -> None:
        """
        Save the current policy parameters.

        Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
        refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
        Therefore, this method is only used to save the final model.
        """
        self._agent.save(save_path)

    def load(self, model_path: Path) -> None:
        """
        Load agent from a checkpoint.

        NOTE(review): takes a ``model_path`` argument while ``PolicyABC.load`` declares none - confirm the intended
        interface.
        """
        self._agent = self._agent_class.load(model_path, env=self.session.env)

    @classmethod
    def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
        """Create an agent from config file."""
        return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,18 +1,21 @@
"""PrimAITE session - the main entry point to training agents on PrimAITE."""
from copy import deepcopy
from enum import Enum
from ipaddress import IPv4Address
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, SupportsFloat, Tuple
from arcd_gate.client.gate_client import ActType, GATEClient
from gymnasium import spaces
import gymnasium
from gymnasium.core import ActType, ObsType
from gymnasium.spaces.utils import flatten, flatten_space
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
from primaite import getLogger
from primaite.game.agent.actions import ActionManager
from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, RandomAgent
from primaite.game.agent.observations import ObservationSpace
from primaite.game.agent.interface import AbstractAgent, AgentSettings, DataManipulationAgent, ProxyAgent, RandomAgent
from primaite.game.agent.observations import ObservationManager
from primaite.game.agent.rewards import RewardFunction
from primaite.game.io import SessionIO, SessionIOSettings
from primaite.game.policy.policy import PolicyABC
from primaite.simulator.network.hardware.base import Link, NIC, Node
from primaite.simulator.network.hardware.nodes.computer import Computer
from primaite.simulator.network.hardware.nodes.router import ACLAction, Router
@@ -34,109 +37,62 @@ from primaite.simulator.system.services.web_server.web_server import WebServer
_LOGGER = getLogger(__name__)
class PrimaiteGATEClient(GATEClient):
"""Lightweight wrapper around the GATEClient class that allows PrimAITE to message GATE."""
class PrimaiteGymEnv(gymnasium.Env):
"""
Thin wrapper env to provide agents with a gymnasium API.
def __init__(self, parent_session: "PrimaiteSession", service_port: int = 50000):
"""
Create a new GATE client for PrimAITE.
This is always a single agent environment since gymnasium is a single agent API. Therefore, we can make some
assumptions about the agent list always having a list of length 1.
"""
:param parent_session: The parent session object.
:type parent_session: PrimaiteSession
:param service_port: The port on which the GATE service is running.
:type service_port: int, optional
"""
super().__init__(service_port=service_port)
self.parent_session: "PrimaiteSession" = parent_session
def __init__(self, session: "PrimaiteSession", agents: List[ProxyAgent]):
"""Initialise the environment."""
super().__init__()
self.session: "PrimaiteSession" = session
self.agent: ProxyAgent = agents[0]
@property
def rl_framework(self) -> str:
"""The reinforcement learning framework to use."""
return self.parent_session.training_options.rl_framework
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen by the RL policy
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored
self.session.apply_agent_actions()
self.session.advance_timestep()
state = self.session.get_sim_state()
self.session.update_agents(state)
@property
def rl_algorithm(self) -> str:
"""The reinforcement learning algorithm to use."""
return self.parent_session.training_options.rl_algorithm
@property
def seed(self) -> Optional[int]:
"""The seed to use for the environment's random number generator."""
return self.parent_session.training_options.seed
@property
def n_learn_episodes(self) -> int:
"""The number of episodes in each learning run."""
return self.parent_session.training_options.n_learn_episodes
@property
def n_learn_steps(self) -> int:
"""The number of steps in each learning episode."""
return self.parent_session.training_options.n_learn_steps
@property
def n_eval_episodes(self) -> int:
"""The number of episodes in each evaluation run."""
return self.parent_session.training_options.n_eval_episodes
@property
def n_eval_steps(self) -> int:
"""The number of steps in each evaluation episode."""
return self.parent_session.training_options.n_eval_steps
@property
def action_space(self) -> spaces.Space:
"""The gym action space of the agent."""
return self.parent_session.rl_agent.action_space.space
@property
def observation_space(self) -> spaces.Space:
"""The gymnasium observation space of the agent."""
return flatten_space(self.parent_session.rl_agent.observation_space.space)
def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, Dict]:
"""Take a step in the environment.
This method is called by GATE to advance the simulation by one timestep.
:param action: The agent's action.
:type action: ActType
:return: The observation, reward, terminal flag, truncated flag, and info dictionary.
:rtype: Tuple[ObsType, float, bool, bool, Dict]
"""
self.parent_session.rl_agent.most_recent_action = action
self.parent_session.step()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
rew = self.parent_session.rl_agent.reward_function.calculate(state)
term = False
trunc = False
next_obs = self._get_obs()
reward = self.agent.reward_function.current_reward
terminated = False
truncated = self.session.calculate_truncated()
info = {}
return obs, rew, term, trunc, info
def reset(self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None) -> Tuple[ObsType, Dict]:
"""Reset the environment.
return next_obs, reward, terminated, truncated, info
This method is called when the environment is initialized and at the end of each episode.
def reset(self, seed: Optional[int] = None) -> Tuple[ObsType, Dict[str, Any]]:
"""Reset the environment."""
self.session.reset()
state = self.session.get_sim_state()
self.session.update_agents(state)
next_obs = self._get_obs()
info = {}
return next_obs, info
:param seed: The seed to use for the environment's random number generator.
:type seed: int, optional
:param options: Additional options for the reset. None are used by PrimAITE but this is included for
compatibility with GATE.
:type options: dict[str, Any], optional
:return: The initial observation and an empty info dictionary.
:rtype: Tuple[ObsType, Dict]
"""
self.parent_session.reset()
state = self.parent_session.simulation.describe_state()
obs = self.parent_session.rl_agent.observation_space.observe(state)
obs = flatten(self.parent_session.rl_agent.observation_space.space, obs)
return obs, {}
@property
def action_space(self) -> gymnasium.Space:
    """Return the action space of the environment."""
    # The space itself is owned by the agent's action manager.
    return self.agent.action_manager.space
def close(self):
    """Close the environment.

    Delegates to the parent session's ``close`` method.
    NOTE(review): other methods on this class use ``self.session`` /
    ``self.agent`` — confirm ``parent_session`` is still the right attribute.
    """
    self.parent_session.close()
@property
def observation_space(self) -> gymnasium.Space:
    """Return the observation space of the environment."""
    # Flatten the agent's (possibly nested) observation space into a single
    # flat gymnasium space so standard RL frameworks can consume it.
    return gymnasium.spaces.flatten_space(self.agent.observation_manager.space)
def _get_obs(self) -> ObsType:
    """Return the agent's current observation, flattened to match ``observation_space``."""
    unflat_space = self.agent.observation_manager.space
    unflat_obs = self.agent.observation_manager.current_observation
    # Flatten with the same space used by the observation_space property so
    # the shapes agree.
    return gymnasium.spaces.flatten(unflat_space, unflat_obs)
class PrimaiteSessionOptions(BaseModel):
@@ -146,6 +102,8 @@ class PrimaiteSessionOptions(BaseModel):
Currently this is used to restrict which ports and protocols exist in the world of the simulation.
"""
model_config = ConfigDict(extra="forbid")
ports: List[str]
protocols: List[str]
@@ -153,48 +111,105 @@ class PrimaiteSessionOptions(BaseModel):
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
rl_framework: str
rl_algorithm: str
seed: Optional[int]
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_learn_steps: int
n_eval_episodes: int
n_eval_steps: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and connections to ARCD GATE."""
"""The main entrypoint for PrimAITE sessions, this manages a simulation, agents, and environments."""
def __init__(self):
"""Initialise a PrimaiteSession object."""
self.simulation: Simulation = Simulation()
"""Simulation object with which the agents will interact."""
self._simulation_initial_state = deepcopy(self.simulation)
"""The Simulation original state (deepcopy of the original Simulation)."""
self.agents: List[AbstractAgent] = []
"""List of agents."""
self.rl_agent: AbstractAgent
"""The agent from the list which communicates with GATE to perform reinforcement learning."""
self.rl_agents: List[ProxyAgent] = []
"""Subset of agent list including only the reinforcement learning agents."""
self.step_counter: int = 0
"""Current timestep within the episode."""
self.episode_counter: int = 0
"""Current episode number."""
self.options: PrimaiteSessionOptions
"""Special options that apply for the entire game."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.ref_map_nodes: Dict[str, Node] = {}
"""Mapping from unique node reference name to node object. Used when parsing config files."""
self.ref_map_services: Dict[str, Service] = {}
"""Mapping from human-readable service reference to service object. Used for parsing config files."""
self.ref_map_applications: Dict[str, Application] = {}
"""Mapping from human-readable application reference to application object. Used for parsing config files."""
self.ref_map_links: Dict[str, Link] = {}
"""Mapping from human-readable link reference to link object. Used when parsing config files."""
self.gate_client: PrimaiteGATEClient = PrimaiteGATEClient(self)
"""Reference to a GATE Client object, which will send data to GATE service for training RL agent."""
self.env: PrimaiteGymEnv
"""The environment that the agent can consume. Could be PrimaiteEnv."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.io_manager = SessionIO()
"""IO manager for the session."""
def start_session(self) -> None:
"""Commence the training session, this gives the GATE client control over the simulation/agent loop."""
self.gate_client.start()
"""Commence the training session."""
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
    """Save the RL models via the session's IO manager.

    The save path is generated by ``SessionIO.generate_model_save_path``.
    TODO(review): the model name is hard-coded — confirm intended naming scheme.
    """
    save_path = self.io_manager.generate_model_save_path("temp_model_name")
    self.policy.save(save_path)
def step(self):
"""
@@ -208,57 +223,76 @@ class PrimaiteSession:
4. Each agent chooses an action based on the observation.
5. Each agent converts the action to a request.
6. The simulation applies the requests.
Warning: This method should only be used with scripted agents. For RL agents, the environment that the agent
interacts with should implement a step method that calls methods used by this method. For example, if using a
single-agent gym, make sure to update the ProxyAgent's action with the action before calling
``self.apply_agent_actions()``.
"""
_LOGGER.debug(f"Stepping primaite session. Step counter: {self.step_counter}")
# currently designed with assumption that all agents act once per step in order
# Get the current state of the simulation
sim_state = self.get_sim_state()
# Update agents' observations and rewards based on the current state
self.update_agents(sim_state)
# Apply all actions to simulation as requests
self.apply_agent_actions()
# Advance timestep
self.advance_timestep()
def get_sim_state(self) -> Dict:
    """Get the current state of the simulation.

    Thin wrapper around ``Simulation.describe_state`` for use by agents/envs.
    """
    return self.simulation.describe_state()
def update_agents(self, state: Dict) -> None:
"""Update agents' observations and rewards based on the current state."""
for agent in self.agents:
# 3. primaite session asks simulation to provide initial state
# 4. primate session gives state to all agents
# 5. primaite session asks agents to produce an action based on most recent state
_LOGGER.debug(f"Sending simulation state to agent {agent.agent_name}")
sim_state = self.simulation.describe_state()
agent.update_observation(state)
agent.update_reward(state)
# 6. each agent takes most recent state and converts it to CAOS observation
agent_obs = agent.convert_state_to_obs(sim_state)
def apply_agent_actions(self) -> None:
    """Ask every agent for an action and apply it to the simulation as a request."""
    for agent in self.agents:
        # Decide on an action using the agent's most recently cached
        # observation and reward (populated by update_agents).
        obs = agent.observation_manager.current_observation
        rew = agent.reward_function.current_reward
        action_choice, options = agent.get_action(obs, rew)
        # Convert the abstract (CAOS) action into a simulation request.
        request = agent.format_request(action_choice, options)
        self.simulation.apply_request(request)
# 7. meanwhile each agent also takes state and calculates reward
agent_reward = agent.calculate_reward_from_state(sim_state)
# 8. each agent takes observation and applies decision rule to observation to create CAOS
# action(such as random, rulebased, or send to GATE) (therefore, converting CAOS action
# to discrete(40) is only necessary for purposes of RL learning, therefore that bit of
# code should live inside of the GATE agent subclass)
# gets action in CAOS format
_LOGGER.debug("Getting agent action")
agent_action, action_options = agent.get_action(agent_obs, agent_reward)
# 9. CAOS action is converted into request (extra information might be needed to enrich
# the request, this is what the execution definition is there for)
_LOGGER.debug(f"Formatting agent action {agent_action}") # maybe too many debug log statements
agent_request = agent.format_request(agent_action, action_options)
# 10. primaite session receives the action from the agents and asks the simulation to apply each
_LOGGER.debug(f"Sending request to simulation: {agent_request}")
self.simulation.apply_request(agent_request)
_LOGGER.debug(f"Initiating simulation step {self.step_counter}")
self.simulation.apply_timestep(self.step_counter)
def advance_timestep(self) -> None:
    """Increment the step counter and propagate the new timestep to the simulation."""
    self.step_counter += 1
    _LOGGER.debug(f"Advancing timestep to {self.step_counter} ")
    self.simulation.apply_timestep(self.step_counter)
def calculate_truncated(self) -> bool:
    """Report whether the current episode should be truncated.

    An episode is truncated once the step counter reaches the configured
    ``max_steps_per_episode`` from the session's training options.

    :return: ``True`` if the episode has hit the step limit, otherwise ``False``.
    :rtype: bool
    """
    # Return the comparison directly rather than if/return True/return False.
    return self.step_counter >= self.training_options.max_steps_per_episode
def reset(self) -> None:
"""Reset the session, this will reset the simulation."""
return NotImplemented
self.episode_counter += 1
self.step_counter = 0
_LOGGER.debug(f"Restting primaite session, episode = {self.episode_counter}")
self.simulation = deepcopy(self._simulation_initial_state)
def close(self) -> None:
"""Close the session, this will stop the gate client and close the simulation."""
"""Close the session, this will stop the env and close the simulation."""
return NotImplemented
@classmethod
def from_config(cls, cfg: dict) -> "PrimaiteSession":
def from_config(cls, cfg: dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent. Used by GATE.
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteSession.
3. simulation: defines the network topology and the initial state of the simulation.
@@ -276,6 +310,11 @@ class PrimaiteSession:
protocols=cfg["game_config"]["protocols"],
)
sess.training_options = TrainingOptions(**cfg["training_config"])
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_settings = cfg.get("io_settings", {})
sess.io_manager.settings = SessionIOSettings(**io_settings)
sim = sess.simulation
net = sim.network
@@ -425,7 +464,7 @@ class PrimaiteSession:
reward_function_cfg = agent_cfg["reward_function"]
# CREATE OBSERVATION SPACE
obs_space = ObservationSpace.from_config(observation_space_cfg, sess)
obs_space = ObservationManager.from_config(observation_space_cfg, sess)
# CREATE ACTION SPACE
action_space_cfg["options"]["node_uuids"] = []
@@ -479,16 +518,15 @@ class PrimaiteSession:
agent_settings=agent_settings,
)
sess.agents.append(new_agent)
elif agent_type == "GATERLAgent":
new_agent = RandomAgent(
elif agent_type == "ProxyAgent":
new_agent = ProxyAgent(
agent_name=agent_cfg["ref"],
action_space=action_space,
observation_space=obs_space,
reward_function=rew_function,
agent_settings=agent_settings,
)
sess.agents.append(new_agent)
sess.rl_agent = new_agent
sess.rl_agents.append(new_agent)
elif agent_type == "RedDatabaseCorruptingAgent":
new_agent = DataManipulationAgent(
agent_name=agent_cfg["ref"],
@@ -501,4 +539,14 @@ class PrimaiteSession:
else:
print("agent type not found")
# CREATE ENVIRONMENT
sess.env = PrimaiteGymEnv(session=sess, agents=sess.rl_agents)
# CREATE POLICY
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
sess._simulation_initial_state = deepcopy(sess.simulation) # noqa
return sess

View File

@@ -15,6 +15,7 @@ _LOGGER = getLogger(__name__)
def run(
config_path: Optional[Union[str, Path]] = "",
agent_load_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
@@ -31,7 +32,7 @@ def run(
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()

View File

@@ -1,14 +1,26 @@
"""Warning: SIM_OUTPUT is a mutable global variable for the simulation output directory."""
from datetime import datetime
from pathlib import Path
from primaite import _PRIMAITE_ROOT
SIM_OUTPUT = None
"A path at the repo root dir to use temporarily for sim output testing while in dev."
# TODO: Remove once we integrate the simulation into PrimAITE and it uses the primaite session path
__all__ = ["SIM_OUTPUT"]
if not SIM_OUTPUT:
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
sim_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
SIM_OUTPUT = _PRIMAITE_ROOT.parent.parent / "simulation_output" / date_dir / sim_path
SIM_OUTPUT.mkdir(exist_ok=True, parents=True)
class __SimOutput:
    """Holder for the simulation output directory.

    Wraps a mutable ``Path`` so other modules can read ``SIM_OUTPUT.path``
    while the session can repoint it at runtime.
    """

    def __init__(self):
        # Default location: <repo root>/simulation_output/<timestamp>.
        # NOTE: the directory is only created when the setter is used.
        self._path: Path = (
            _PRIMAITE_ROOT.parent.parent / "simulation_output" / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        )

    @property
    def path(self) -> Path:
        """The current simulation output directory."""
        return self._path

    @path.setter
    def path(self, new_path: Path) -> None:
        # Setting the path also ensures the directory exists on disk.
        self._path = new_path
        self._path.mkdir(exist_ok=True, parents=True)

# Module-level singleton; import sites read/write SIM_OUTPUT.path.
SIM_OUTPUT = __SimOutput()

View File

@@ -957,7 +957,7 @@ class Node(SimComponent):
if not kwargs.get("session_manager"):
kwargs["session_manager"] = SessionManager(sys_log=kwargs.get("sys_log"), arp_cache=kwargs.get("arp"))
if not kwargs.get("root"):
kwargs["root"] = SIM_OUTPUT / kwargs["hostname"]
kwargs["root"] = SIM_OUTPUT.path / kwargs["hostname"]
if not kwargs.get("file_system"):
kwargs["file_system"] = FileSystem(sys_log=kwargs["sys_log"], sim_root=kwargs["root"] / "fs")
if not kwargs.get("software_manager"):

View File

@@ -140,7 +140,7 @@ def arcd_uc2_network() -> Network:
network.connect(endpoint_b=client_1.ethernet_port[1], endpoint_a=switch_2.switch_ports[1])
client_1.software_manager.install(DataManipulationBot)
db_manipulation_bot: DataManipulationBot = client_1.software_manager.software["DataManipulationBot"]
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DROP TABLE IF EXISTS user;")
db_manipulation_bot.configure(server_ip_address=IPv4Address("192.168.1.14"), payload="DELETE")
# Client 2
client_2 = Computer(

View File

@@ -2,13 +2,14 @@ from ipaddress import IPv4Address
from typing import Any, Dict, Optional
from uuid import uuid4
from prettytable import PrettyTable
from primaite import getLogger
from primaite.simulator.network.transmission.network_layer import IPProtocol
from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application, ApplicationOperatingState
from primaite.simulator.system.core.software_manager import SoftwareManager
_LOGGER = getLogger(__name__)
class DatabaseClient(Application):
"""
@@ -130,7 +131,7 @@ class DatabaseClient(Application):
def execute(self) -> None:
"""Run the DatabaseClient."""
# super().execute()
super().execute()
if self.operating_state == ApplicationOperatingState.RUNNING:
self.connect()
@@ -148,21 +149,6 @@ class DatabaseClient(Application):
self._query_success_tracker[query_id] = False
return self._query(sql=sql, query_id=query_id)
def _print_data(self, data: Dict):
    """
    Display query result data in tabular format.

    :param data: Mapping of row key to row dictionary. The first row is
        passed to PrettyTable as the field names (presumably yielding its
        keys — TODO confirm). Nothing is printed if ``data`` is empty.
    """
    if data:
        # Headers are taken from the first row.
        table = PrettyTable(list(data.values())[0])
        table.align = "l"
        table.title = f"{self.sys_log.hostname} Database Client"
        for row in data.values():
            table.add_row(row.values())
        print(table)
def receive(self, payload: Any, session_id: str, **kwargs) -> bool:
"""
Receive a payload from the Software Manager.
@@ -179,5 +165,5 @@ class DatabaseClient(Application):
status_code = payload.get("status_code")
self._query_success_tracker[query_id] = status_code == 200
if self._query_success_tracker[query_id]:
self._print_data(payload["data"])
_LOGGER.debug(f"Received payload {payload}")
return True

View File

@@ -30,7 +30,7 @@ class WebBrowser(Application):
kwargs["port"] = Port.HTTP
super().__init__(**kwargs)
self.execute()
self.run()
def describe_state(self) -> Dict:
"""
@@ -135,6 +135,3 @@ class WebBrowser(Application):
self.sys_log.info(f"{self.name}: Received HTTP {payload.status_code.value}")
self.latest_response = payload
return True
def execute(self):
pass

View File

@@ -75,7 +75,7 @@ class PacketCapture:
def _get_log_path(self) -> Path:
"""Get the path for the log file."""
root = SIM_OUTPUT / self.hostname
root = SIM_OUTPUT.path / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self._logger_name}.log"

View File

@@ -100,8 +100,16 @@ class SoftwareManager:
self.node.uninstall_application(software)
elif isinstance(software, Service):
self.node.uninstall_service(software)
for key, value in self.port_protocol_mapping.items():
if value.name == software_name:
self.port_protocol_mapping.pop(key)
break
for key, value in self._software_class_to_name_map.items():
if value == software_name:
self._software_class_to_name_map.pop(key)
break
del software
self.sys_log.info(f"Deleted {software_name}")
self.sys_log.info(f"Uninstalled {software_name}")
return
self.sys_log.error(f"Cannot uninstall {software_name} as it is not installed")

View File

@@ -41,7 +41,6 @@ class SysLog:
JSON-like messages.
"""
log_path = self._get_log_path()
file_handler = logging.FileHandler(filename=log_path)
file_handler.setLevel(logging.DEBUG)
@@ -81,7 +80,7 @@ class SysLog:
:return: Path object representing the location of the log file.
"""
root = SIM_OUTPUT / self.hostname
root = SIM_OUTPUT.path / self.hostname
root.mkdir(exist_ok=True, parents=True)
return root / f"{self.hostname}_sys.log"

View File

@@ -1,10 +1,6 @@
import sqlite3
from datetime import datetime
from ipaddress import IPv4Address
from sqlite3 import OperationalError
from typing import Any, Dict, List, Optional, Union
from prettytable import MARKDOWN, PrettyTable
from typing import Any, Dict, List, Literal, Optional, Union
from primaite.simulator.file_system.file_system import File
from primaite.simulator.network.transmission.network_layer import IPProtocol
@@ -19,7 +15,7 @@ class DatabaseService(Service):
"""
A class for simulating a generic SQL Server service.
This class inherits from the `Service` class and provides methods to manage and query a SQLite database.
This class inherits from the `Service` class and provides methods to simulate a SQL database.
"""
password: Optional[str] = None
@@ -41,38 +37,6 @@ class DatabaseService(Service):
super().__init__(**kwargs)
self._db_file: File
self._create_db_file()
self._connect()
def _connect(self):
    """(Re)open the sqlite connection and cursor against the simulated db file."""
    self._conn = sqlite3.connect(self._db_file.sim_path)
    self._cursor = self._conn.cursor()
def tables(self) -> List[str]:
    """
    Get a list of table names present in the database.

    Queries ``sqlite_master``, excluding sqlite's internal
    ``sqlite_sequence`` bookkeeping table.

    :return: List of table names; empty list if the query produced no rows.
    """
    sql = "SELECT name FROM sqlite_master WHERE type='table' AND name != 'sqlite_sequence';"
    results = self._process_sql(sql, None)
    # _process_sql keys its "data" dict by the first column (the table name).
    if isinstance(results["data"], dict):
        return list(results["data"].keys())
    return []
def show(self, markdown: bool = False):
    """
    Print the database's table names using PrettyTable.

    :param markdown: Whether to output the table in Markdown format.
    """
    table = PrettyTable(["Table"])
    if markdown:
        table.set_style(MARKDOWN)
    table.align = "l"
    table.title = f"{self.file_system.sys_log.hostname} Database"
    # One row per table name reported by self.tables().
    for row in self.tables():
        table.add_row([row])
    print(table)
def configure_backup(self, backup_server: IPv4Address):
"""
@@ -89,8 +53,6 @@ class DatabaseService(Service):
self.sys_log.error(f"{self.name} - {self.sys_log.hostname}: not configured.")
return False
self._conn.close()
software_manager: SoftwareManager = self.software_manager
ftp_client_service: FTPClient = software_manager.software["FTPClient"]
@@ -98,12 +60,10 @@ class DatabaseService(Service):
response = ftp_client_service.send_file(
dest_ip_address=self.backup_server,
src_file_name=self._db_file.name,
src_folder_name=self._db_file.folder.name,
src_folder_name=self.folder.name,
dest_folder_name=str(self.uuid),
dest_file_name="database.db",
real_file_path=self._db_file.sim_path,
)
self._connect()
if response:
return True
@@ -125,25 +85,29 @@ class DatabaseService(Service):
dest_ip_address=self.backup_server,
)
if response:
self._conn.close()
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self._connect()
if not response:
self.sys_log.error("Unable to restore database backup.")
return False
return self._db_file is not None
# replace db file
self.file_system.delete_file(folder_name=self.folder.name, file_name="downloads.db")
self.file_system.copy_file(
src_folder_name="downloads", src_file_name="database.db", dst_folder_name=self.folder.name
)
self._db_file = self.file_system.get_file(folder_name=self.folder.name, file_name="database.db")
self.sys_log.error("Unable to restore database backup.")
return False
if self._db_file is None:
self.sys_log.error("Copying database backup failed.")
return False
self.set_health_state(SoftwareHealthState.GOOD)
return True
def _create_db_file(self):
"""Creates the Simulation File and sqlite file in the file system."""
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db", real=True)
self.folder = self._db_file.folder
self._db_file: File = self.file_system.create_file(folder_name="database", file_name="database.db")
self.folder = self.file_system.get_folder_by_id(self._db_file.folder_id)
def _process_connect(
self, session_id: str, password: Optional[str] = None
@@ -163,31 +127,32 @@ class DatabaseService(Service):
status_code = 404 # service not found
return {"status_code": status_code, "type": "connect_response", "response": status_code == 200}
def _process_sql(self, query: str, query_id: str) -> Dict[str, Union[int, List[Any]]]:
def _process_sql(self, query: Literal["SELECT", "DELETE"], query_id: str) -> Dict[str, Union[int, List[Any]]]:
"""
Executes the given SQL query and returns the result.
Possible queries:
- SELECT : returns the data
- DELETE : deletes the data
:param query: The SQL query to be executed.
:return: Dictionary containing status code and data fetched.
"""
self.sys_log.info(f"{self.name}: Running {query}")
try:
self._cursor.execute(query)
self._conn.commit()
except OperationalError:
# Handle the case where the table does not exist.
self.sys_log.error(f"{self.name}: Error, query failed")
return {"status_code": 404, "data": {}}
data = []
description = self._cursor.description
if description:
headers = []
for header in description:
headers.append(header[0])
data = self._cursor.fetchall()
if data and headers:
data = {row[0]: {header: value for header, value in zip(headers, row)} for row in data}
return {"status_code": 200, "type": "sql", "data": data, "uuid": query_id}
if query == "SELECT":
if self.health_state_actual == SoftwareHealthState.GOOD:
return {"status_code": 200, "type": "sql", "data": True, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
elif query == "DELETE":
if self.health_state_actual == SoftwareHealthState.GOOD:
self.health_state_actual = SoftwareHealthState.COMPROMISED
return {"status_code": 200, "type": "sql", "data": False, "uuid": query_id}
else:
return {"status_code": 404, "data": False}
else:
# Invalid query
return {"status_code": 500, "data": False}
def describe_state(self) -> Dict:
"""

View File

@@ -106,7 +106,7 @@ class WebServer(Service):
# get data from DatabaseServer
db_client: DatabaseClient = self.software_manager.software["DatabaseClient"]
# get all users
if db_client.query("SELECT * FROM user;"):
if db_client.query("SELECT"):
# query succeeded
response.status_code = HttpStatusCode.OK

View File

@@ -1,12 +0,0 @@
"""Utility script to start the gate server for running PrimAITE in attached mode."""
from arcd_gate.server.gate_service import GATEService
def start_gate_server():
    """Start the gate server."""
    # Instantiate the GATE service and hand control to it; whether start()
    # blocks is not visible here — presumably it runs the service loop.
    service = GATEService()
    service.start()
if __name__ == "__main__":
start_gate_server()

View File

@@ -0,0 +1,725 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
se3ed: 333 # Purposeful typo to check that error is raised with bad configuration.
n_learn_steps: 2560
n_eval_episodes: 5
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
- operating_status
- health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
            ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
        default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -0,0 +1,729 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 0
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
                - operating_status
                - health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
            # <not yet implemented>
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
            ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
        default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -0,0 +1,733 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 10
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
io_settings:
save_checkpoints: true
checkpoint_interval: 5
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
                - operating_status
                - health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
            # <not yet implemented>
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
            ip_address: 192.168.10.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
        default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -0,0 +1,729 @@
training_config:
rl_framework: SB3
rl_algorithm: PPO
seed: 333
n_learn_episodes: 10
n_eval_episodes: 0
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 1
agent_references:
- defender
game_config:
ports:
- ARP
- DNS
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
agents:
- ref: client_1_green_user
team: GREEN
type: GreenWebBrowsingAgent
observation_space:
type: UC2GreenObservation
action_space:
action_list:
- type: DONOTHING
# <not yet implemented>
# - type: NODE_LOGON
# - type: NODE_LOGOFF
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# target_address: arcd.com
options:
nodes:
- node_ref: client_2
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
max_nics_per_node: 2
max_acl_rules: 10
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_step: 5
frequency: 4
variance: 3
- ref: client_1_data_manipulation_red_bot
team: RED
type: RedDatabaseCorruptingAgent
observation_space:
type: UC2RedObservation
options:
nodes:
- node_ref: client_1
observations:
- logon_status
- operating_status
services:
- service_ref: data_manipulation_bot
observations:
operating_status
health_status
folders: {}
action_space:
action_list:
- type: DONOTHING
#<not yet implemented
# - type: NODE_APPLICATION_EXECUTE
# options:
# execution_definition:
# server_ip: 192.168.1.14
# payload: "DROP TABLE IF EXISTS user;"
# success_rate: 80%
- type: NODE_FILE_DELETE
- type: NODE_FILE_CORRUPT
# - type: NODE_FOLDER_DELETE
# - type: NODE_FOLDER_CORRUPT
- type: NODE_OS_SCAN
# - type: NODE_LOGON
# - type: NODE_LOGOFF
options:
nodes:
- node_ref: client_1
max_folders_per_node: 1
max_files_per_folder: 1
max_services_per_node: 1
reward_function:
reward_components:
- type: DUMMY
agent_settings: # options specific to this particular agent type, basically args of __init__(self)
start_step: 25
frequency: 20
variance: 5
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: UC2BlueObservation
options:
num_services_per_node: 1
num_folders_per_node: 1
num_files_per_folder: 1
num_nics_per_node: 2
nodes:
- node_ref: domain_controller
services:
- service_ref: domain_controller_dns_server
- node_ref: web_server
services:
- service_ref: web_server_database_client
- node_ref: database_server
services:
- service_ref: database_service
folders:
- folder_name: database
files:
- file_name: database.db
- node_ref: backup_server
# services:
# - service_ref: backup_service
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
links:
- link_ref: router_1___switch_1
- link_ref: router_1___switch_2
- link_ref: switch_1___domain_controller
- link_ref: switch_1___web_server
- link_ref: switch_1___database_server
- link_ref: switch_1___backup_server
- link_ref: switch_1___security_suite
- link_ref: switch_2___client_1
- link_ref: switch_2___client_2
- link_ref: switch_2___security_suite
acl:
options:
max_acl_rules: 10
router_node_ref: router_1
ip_address_order:
- node_ref: domain_controller
nic_num: 1
- node_ref: web_server
nic_num: 1
- node_ref: database_server
nic_num: 1
- node_ref: backup_server
nic_num: 1
- node_ref: security_suite
nic_num: 1
- node_ref: client_1
nic_num: 1
- node_ref: client_2
nic_num: 1
- node_ref: security_suite
nic_num: 2
ics: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_SERVICE_SCAN
- type: NODE_SERVICE_STOP
- type: NODE_SERVICE_START
- type: NODE_SERVICE_PAUSE
- type: NODE_SERVICE_RESUME
- type: NODE_SERVICE_RESTART
- type: NODE_SERVICE_DISABLE
- type: NODE_SERVICE_ENABLE
- type: NODE_FILE_SCAN
- type: NODE_FILE_CHECKHASH
- type: NODE_FILE_DELETE
- type: NODE_FILE_REPAIR
- type: NODE_FILE_RESTORE
- type: NODE_FOLDER_SCAN
- type: NODE_FOLDER_CHECKHASH
- type: NODE_FOLDER_REPAIR
- type: NODE_FOLDER_RESTORE
- type: NODE_OS_SCAN
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: NODE_RESET
- type: NETWORK_ACL_ADDRULE
options:
target_router_ref: router_1
- type: NETWORK_ACL_REMOVERULE
options:
target_router_ref: router_1
- type: NETWORK_NIC_ENABLE
- type: NETWORK_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
# scan webapp service
1:
action: NODE_SERVICE_SCAN
options:
node_id: 2
service_id: 1
# stop webapp service
2:
action: NODE_SERVICE_STOP
options:
node_id: 2
service_id: 1
# start webapp service
3:
action: "NODE_SERVICE_START"
options:
node_id: 2
service_id: 1
4:
action: "NODE_SERVICE_PAUSE"
options:
node_id: 2
service_id: 1
5:
action: "NODE_SERVICE_RESUME"
options:
node_id: 2
service_id: 1
6:
action: "NODE_SERVICE_RESTART"
options:
node_id: 2
service_id: 1
7:
action: "NODE_SERVICE_DISABLE"
options:
node_id: 2
service_id: 1
8:
action: "NODE_SERVICE_ENABLE"
options:
node_id: 2
service_id: 1
9:
action: "NODE_FILE_SCAN"
options:
node_id: 3
folder_id: 1
file_id: 1
10:
action: "NODE_FILE_CHECKHASH"
options:
node_id: 3
folder_id: 1
file_id: 1
11:
action: "NODE_FILE_DELETE"
options:
node_id: 3
folder_id: 1
file_id: 1
12:
action: "NODE_FILE_REPAIR"
options:
node_id: 3
folder_id: 1
file_id: 1
13:
action: "NODE_FILE_RESTORE"
options:
node_id: 3
folder_id: 1
file_id: 1
14:
action: "NODE_FOLDER_SCAN"
options:
node_id: 3
folder_id: 1
15:
action: "NODE_FOLDER_CHECKHASH"
options:
node_id: 3
folder_id: 1
16:
action: "NODE_FOLDER_REPAIR"
options:
node_id: 3
folder_id: 1
17:
action: "NODE_FOLDER_RESTORE"
options:
node_id: 3
folder_id: 1
18:
action: "NODE_OS_SCAN"
options:
node_id: 3
19:
action: "NODE_SHUTDOWN"
options:
node_id: 6
20:
action: "NODE_STARTUP"
options:
node_id: 6
21:
action: "NODE_RESET"
options:
node_id: 6
22:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
23:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 1
source_port_id: 1
dest_port_id: 1
protocol_id: 1
24:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
25:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 3
source_port_id: 1
dest_port_id: 1
protocol_id: 3
26:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 7
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
27:
action: "NETWORK_ACL_ADDRULE"
options:
position: 1
permission: 2
source_ip_id: 8
dest_ip_id: 4
source_port_id: 1
dest_port_id: 1
protocol_id: 3
28:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 0
29:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 1
30:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 2
31:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 3
32:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 4
33:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 5
34:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 6
35:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 7
36:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 8
37:
action: "NETWORK_ACL_REMOVERULE"
options:
position: 9
38:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 1
nic_id: 1
39:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 1
nic_id: 1
40:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 2
nic_id: 1
41:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 2
nic_id: 1
42:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 3
nic_id: 1
43:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 3
nic_id: 1
44:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 4
nic_id: 1
45:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 4
nic_id: 1
46:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 1
47:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 1
48:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 5
nic_id: 2
49:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 5
nic_id: 2
50:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 6
nic_id: 1
51:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 6
nic_id: 1
52:
action: "NETWORK_NIC_DISABLE"
options:
node_id: 7
nic_id: 1
53:
action: "NETWORK_NIC_ENABLE"
options:
node_id: 7
nic_id: 1
options:
nodes:
- node_ref: router_1
- node_ref: switch_1
- node_ref: switch_2
- node_ref: domain_controller
- node_ref: web_server
- node_ref: database_server
- node_ref: backup_server
- node_ref: security_suite
- node_ref: client_1
- node_ref: client_2
max_folders_per_node: 2
max_files_per_folder: 2
max_services_per_node: 2
max_nics_per_node: 8
max_acl_rules: 10
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.5
options:
node_ref: database_server
folder_name: database
file_name: database.db
- type: WEB_SERVER_404_PENALTY
weight: 0.5
options:
node_ref: web_server
service_ref: web_server_web_service
agent_settings:
# ...
simulation:
network:
nodes:
- ref: router_1
type: router
hostname: router_1
num_ports: 5
ports:
1:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
2:
ip_address: 192.168.1.1
subnet_mask: 255.255.255.0
acl:
0:
action: PERMIT
src_port: POSTGRES_SERVER
dst_port: POSTGRES_SERVER
1:
action: PERMIT
src_port: DNS
dst_port: DNS
22:
action: PERMIT
src_port: ARP
dst_port: ARP
23:
action: PERMIT
protocol: ICMP
- ref: switch_1
type: switch
hostname: switch_1
num_ports: 8
- ref: switch_2
type: switch
hostname: switch_2
num_ports: 8
- ref: domain_controller
type: server
hostname: domain_controller
ip_address: 192.168.1.10
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- ref: domain_controller_dns_server
type: DNSServer
options:
domain_mapping:
arcd.com: 192.168.1.12 # web server
- ref: web_server
type: server
hostname: web_server
ip_address: 192.168.1.12
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.10
dns_server: 192.168.1.10
services:
- ref: web_server_database_client
type: DatabaseClient
options:
db_server_ip: 192.168.1.14
- ref: web_server_web_service
type: WebServer
- ref: database_server
type: server
hostname: database_server
ip_address: 192.168.1.14
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: database_service
type: DatabaseService
- ref: backup_server
type: server
hostname: backup_server
ip_address: 192.168.1.16
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
services:
- ref: backup_service
type: DatabaseBackup
- ref: security_suite
type: server
hostname: security_suite
ip_address: 192.168.1.110
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
dns_server: 192.168.1.10
nics:
2: # unfortunately this number is currently meaningless, they're just added in order and take up the next available slot
ip_address: 192.168.10.110
subnet_mask: 255.255.255.0
- ref: client_1
type: computer
hostname: client_1
ip_address: 192.168.10.21
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
services:
- ref: data_manipulation_bot
type: DataManipulationBot
- ref: client_1_dns_client
type: DNSClient
- ref: client_2
type: computer
hostname: client_2
ip_address: 192.168.10.22
subnet_mask: 255.255.255.0
default_gateway: 192.168.10.1
dns_server: 192.168.1.10
applications:
- ref: client_2_web_browser
type: WebBrowser
services:
- ref: client_2_dns_client
type: DNSClient
links:
- ref: router_1___switch_1
endpoint_a_ref: router_1
endpoint_a_port: 1
endpoint_b_ref: switch_1
endpoint_b_port: 8
- ref: router_1___switch_2
endpoint_a_ref: router_1
endpoint_a_port: 2
endpoint_b_ref: switch_2
endpoint_b_port: 8
- ref: switch_1___domain_controller
endpoint_a_ref: switch_1
endpoint_a_port: 1
endpoint_b_ref: domain_controller
endpoint_b_port: 1
- ref: switch_1___web_server
endpoint_a_ref: switch_1
endpoint_a_port: 2
endpoint_b_ref: web_server
endpoint_b_port: 1
- ref: switch_1___database_server
endpoint_a_ref: switch_1
endpoint_a_port: 3
endpoint_b_ref: database_server
endpoint_b_port: 1
- ref: switch_1___backup_server
endpoint_a_ref: switch_1
endpoint_a_port: 4
endpoint_b_ref: backup_server
endpoint_b_port: 1
- ref: switch_1___security_suite
endpoint_a_ref: switch_1
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 1
- ref: switch_2___client_1
endpoint_a_ref: switch_2
endpoint_a_port: 1
endpoint_b_ref: client_1
endpoint_b_port: 1
- ref: switch_2___client_2
endpoint_a_ref: switch_2
endpoint_a_port: 2
endpoint_b_ref: client_2
endpoint_b_port: 1
- ref: switch_2___security_suite
endpoint_a_ref: switch_2
endpoint_a_port: 7
endpoint_b_ref: security_suite
endpoint_b_port: 2

View File

@@ -1,15 +1,12 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import datetime
import shutil
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Union
from unittest.mock import patch
import pytest
import yaml
from primaite import getLogger
from primaite.game.session import PrimaiteSession
# from primaite.environment.primaite_env import Primaite
# from primaite.primaite_session import PrimaiteSession
@@ -19,13 +16,15 @@ from primaite.simulator.network.transmission.transport_layer import Port
from primaite.simulator.system.applications.application import Application
from primaite.simulator.system.core.sys_log import SysLog
from primaite.simulator.system.services.service import Service
from tests.mock_and_patch.get_session_path_mock import get_temp_session_path
from tests.mock_and_patch.get_session_path_mock import temp_user_sessions_path
ACTION_SPACE_NODE_VALUES = 1
ACTION_SPACE_NODE_ACTION_VALUES = 1
_LOGGER = getLogger(__name__)
from primaite import PRIMAITE_PATHS
# PrimAITE v3 stuff
from primaite.simulator.file_system.file_system import FileSystem
from primaite.simulator.network.hardware.base import Node
@@ -70,102 +69,32 @@ def file_system() -> FileSystem:
# PrimAITE v2 stuff
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
class TempPrimaiteSession: # PrimaiteSession):
class TempPrimaiteSession(PrimaiteSession):
"""
A temporary PrimaiteSession class.
Uses context manager for deletion of files upon exit.
"""
# def __init__(
# self,
# training_config_path: Union[str, Path],
# lay_down_config_path: Union[str, Path],
# ):
# super().__init__(training_config_path, lay_down_config_path)
# self.setup()
@classmethod
def from_config(cls, config_path: Union[str, Path]) -> "TempPrimaiteSession":
"""Create a temporary PrimaiteSession object from a config file."""
config_path = Path(config_path)
with open(config_path, "r") as f:
config = yaml.safe_load(f)
# @property
# def env(self) -> Primaite:
# """Direct access to the env for ease of testing."""
# return self._agent_session._env # noqa
return super().from_config(cfg=config)
# def __enter__(self):
# return self
def __enter__(self):
return self
# def __exit__(self, type, value, tb):
# shutil.rmtree(self.session_path)
# _LOGGER.debug(f"Deleted temp session directory: {self.session_path}")
def __exit__(self, type, value, tb):
pass
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
@pytest.fixture
def temp_primaite_session(request):
"""
Provides a temporary PrimaiteSession instance.
It's temporary as it uses a temporary directory as the session path.
To use this fixture you need to:
- parametrize your test function with:
- "temp_primaite_session"
- [[path to training config, path to lay down config]]
- Include the temp_primaite_session fixture as a param in your test
function.
- use the temp_primaite_session as a context manager assigning is the
name 'session'.
.. code:: python
from primaite.config.lay_down_config import dos_very_basic_config_path
from primaite.config.training_config import main_training_config_path
@pytest.mark.parametrize(
"temp_primaite_session",
[
[main_training_config_path(), dos_very_basic_config_path()]
],
indirect=True
)
def test_primaite_session(temp_primaite_session):
with temp_primaite_session as session:
# Learning outputs are saved in session.learning_path
session.learn()
# Evaluation outputs are saved in session.evaluation_path
session.evaluate()
# To ensure that all files are written, you must call .close()
session.close()
# If you need to inspect any session outputs, it must be done
# inside the context manager
# Now that we've exited the context manager, the
# session.session_path directory and its contents are deleted
"""
training_config_path = request.param[0]
lay_down_config_path = request.param[1]
with patch("primaite.agents.agent_abc.get_session_path", get_temp_session_path) as mck:
mck.session_timestamp = datetime.now()
return TempPrimaiteSession(training_config_path, lay_down_config_path)
@pytest.mark.skip("Deprecated") # TODO: implement a similar test for primaite v3
@pytest.fixture
def temp_session_path() -> Path:
"""
Get a temp directory session path the test session will output to.
:return: The session directory path.
"""
session_timestamp = datetime.now()
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_path = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path(tempfile.gettempdir()) / "_primaite" / date_dir / session_path
session_path.mkdir(exist_ok=True, parents=True)
return session_path
def temp_primaite_session(request, monkeypatch) -> TempPrimaiteSession:
"""Create a temporary PrimaiteSession object."""
monkeypatch.setattr(PRIMAITE_PATHS, "user_sessions_path", temp_user_sessions_path())
config_path = request.param[0]
return TempPrimaiteSession.from_config(config_path=config_path)

View File

@@ -0,0 +1,83 @@
import pydantic
import pytest
from tests import TEST_ASSETS_ROOT
from tests.conftest import TempPrimaiteSession
CFG_PATH = TEST_ASSETS_ROOT / "configs/test_primaite_session.yaml"
TRAINING_ONLY_PATH = TEST_ASSETS_ROOT / "configs/train_only_primaite_session.yaml"
EVAL_ONLY_PATH = TEST_ASSETS_ROOT / "configs/eval_only_primaite_session.yaml"
MISCONFIGURED_PATH = TEST_ASSETS_ROOT / "configs/bad_primaite_session.yaml"
class TestPrimaiteSession:
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_creating_session(self, temp_primaite_session):
"""Check that creating a session from config works."""
with temp_primaite_session as session:
if not isinstance(session, TempPrimaiteSession):
raise AssertionError
assert session is not None
assert session.simulation
assert len(session.agents) == 3
assert len(session.rl_agents) == 1
assert session.policy
assert session.env
assert session.simulation.network
assert len(session.simulation.network.nodes) == 10
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_start_session(self, temp_primaite_session):
"""Make sure you can go all the way through the session without errors."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
session_path = session.io_manager.session_path
assert session_path.exists()
print(list(session_path.glob("*")))
checkpoint_dir = session_path / "checkpoints" / "sb3_final"
assert checkpoint_dir.exists()
checkpoint_1 = checkpoint_dir / "sb3_model_640_steps.zip"
checkpoint_2 = checkpoint_dir / "sb3_model_1280_steps.zip"
checkpoint_3 = checkpoint_dir / "sb3_model_1920_steps.zip"
assert checkpoint_1.exists()
assert checkpoint_2.exists()
assert not checkpoint_3.exists()
@pytest.mark.parametrize("temp_primaite_session", [[TRAINING_ONLY_PATH]], indirect=True)
def test_training_only_session(self, temp_primaite_session):
"""Check that you can run a training-only session."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
# TODO: include checks that the model was trained, e.g. that the loss changed and checkpoints were saved?
@pytest.mark.parametrize("temp_primaite_session", [[EVAL_ONLY_PATH]], indirect=True)
def test_eval_only_session(self, temp_primaite_session):
"""Check that you can load a model and run an eval-only session."""
with temp_primaite_session as session:
session: TempPrimaiteSession
session.start_session()
# TODO: include checks that the model was loaded and that the eval-only session ran
def test_error_thrown_on_bad_configuration(self):
with pytest.raises(pydantic.ValidationError):
session = TempPrimaiteSession.from_config(MISCONFIGURED_PATH)
@pytest.mark.parametrize("temp_primaite_session", [[CFG_PATH]], indirect=True)
def test_session_sim_reset(self, temp_primaite_session):
with temp_primaite_session as session:
session: TempPrimaiteSession
client_1 = session.simulation.network.get_node_by_hostname("client_1")
client_1.software_manager.uninstall("DataManipulationBot")
assert "DataManipulationBot" not in client_1.software_manager.software
session.reset()
client_1 = session.simulation.network.get_node_by_hostname("client_1")
assert "DataManipulationBot" in client_1.software_manager.software

View File

@@ -19,16 +19,16 @@ def test_data_manipulation(uc2_network):
db_service.backup_database()
# First check that the DB client on the web_server can successfully query the users table on the database
assert db_client.query("SELECT * FROM user;")
assert db_client.query("SELECT")
# Now we run the DataManipulationBot
db_manipulation_bot.run()
# Now check that the DB client on the web_server cannot query the users table on the database
assert not db_client.query("SELECT * FROM user;")
assert not db_client.query("SELECT")
# Now restore the database
db_service.restore_backup()
# Now check that the DB client on the web_server can successfully query the users table on the database
assert db_client.query("SELECT * FROM user;")
assert db_client.query("SELECT")

View File

@@ -57,7 +57,7 @@ def test_database_client_query(uc2_network):
db_client: DatabaseClient = web_server.software_manager.software["DatabaseClient"]
db_client.connect()
assert db_client.query("SELECT * FROM user;")
assert db_client.query("SELECT")
def test_create_database_backup(uc2_network):

View File

@@ -10,7 +10,7 @@ def test_web_page_home_page(uc2_network):
"""Test to see if the browser is able to open the main page of the web server."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.execute()
web_client.run()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage("http://arcd.com/") is True
@@ -24,7 +24,7 @@ def test_web_page_get_users_page_request_with_domain_name(uc2_network):
"""Test to see if the client can handle requests with domain names"""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.execute()
web_client.run()
assert web_client.operating_state == ApplicationOperatingState.RUNNING
assert web_client.get_webpage("http://arcd.com/users/") is True
@@ -38,7 +38,7 @@ def test_web_page_get_users_page_request_with_ip_address(uc2_network):
"""Test to see if the client can handle requests that use ip_address."""
client_1: Computer = uc2_network.get_node_by_hostname("client_1")
web_client: WebBrowser = client_1.software_manager.software["WebBrowser"]
web_client.execute()
web_client.run()
web_server: Server = uc2_network.get_node_by_hostname("web_server")
web_server_ip = web_server.nics.get(next(iter(web_server.nics))).ip_address

View File

@@ -9,7 +9,7 @@ from primaite import getLogger
_LOGGER = getLogger(__name__)
def get_temp_session_path(session_timestamp: datetime) -> Path:
def temp_user_sessions_path() -> Path:
"""
Get a temp directory session path the test session will output to.

View File

@@ -27,7 +27,7 @@ def test_create_dm_bot(dm_client):
assert data_manipulation_bot.name == "DataManipulationBot"
assert data_manipulation_bot.port == Port.POSTGRES_SERVER
assert data_manipulation_bot.protocol == IPProtocol.TCP
assert data_manipulation_bot.payload == "DROP TABLE IF EXISTS user;"
assert data_manipulation_bot.payload == "DELETE"
def test_dm_bot_logon(dm_bot):