Merged PR 353: Episode Schedule

## Summary
* Made It Possible™ to vary the layout and agents between episodes.
* Standardised the environments so that they all use `env_config` as the `__init__` parameter name; previously `PrimaiteGymEnv` used `game_config`.
* Added a notebook that demonstrates how to use the variable episodes.
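
As a rough illustration of the first bullet, the sketch below mimics two steps the scheduler in this PR performs: looking up the variation files for an episode (looping back once the schedule runs out) and flattening the agent list after the placeholders are filled in. The helper names `variations_for_episode` and `flatten_agents` are hypothetical and are not part of PrimAITE; the real logic lives in `EpisodeListScheduler.__call__`, which also logs a one-off warning on wrap-around.

```python
from typing import Any, Dict, List, Mapping


def variations_for_episode(schedule: Mapping[int, List[str]], episode_num: int) -> List[str]:
    """Return the variation files for an episode, looping once the schedule runs out."""
    # Hypothetical helper: the modulo mirrors the wrap-around in the PR's scheduler.
    return schedule[episode_num % len(schedule)]


def flatten_agents(agents: List[Any]) -> List[Dict]:
    """Flatten one level of nesting left behind by list-valued placeholders."""
    flat: List[Dict] = []
    for entry in agents:
        if isinstance(entry, list):
            # A placeholder such as *greens or *reds expands to a (possibly empty) list.
            flat.extend(entry)
        else:
            # Agents defined directly in the base scenario are plain dicts.
            flat.append(entry)
    return flat


# Same shape as the example schedule.yaml added in this PR.
schedule = {
    0: ["greens_0.yaml", "reds_0.yaml"],
    1: ["greens_0.yaml", "reds_1.yaml"],
    2: ["greens_1.yaml", "reds_1.yaml"],
    3: ["greens_2.yaml", "reds_2.yaml"],
}

print(variations_for_episode(schedule, 2))  # episode 2 is defined in the schedule
print(variations_for_episode(schedule, 4))  # episode 4 wraps around to episode 0

# What an agents list might look like once an empty greens list and one red agent are substituted.
agents = [[], [{"ref": "red_A"}], {"ref": "defender"}]
print(flatten_agents(agents))
```

In PrimAITE itself none of this is called directly: passing a folder path as `env_config` is enough, and the environment asks the scheduler for the right config on every `reset()`.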

## Test process
Checked that existing pytests run. Added more tests. Checked that notebooks run.

## Checklist
- [x] PR is linked to a **work item**
- [x] **acceptance criteria** of linked ticket are met
- [x] performed **self-review** of the code
- [x] written **tests** for any new functionality added with this PR
- [x] updated the **documentation** if this PR changes or adds functionality
- [ ] written/updated **design docs** if this PR implements new functionality
- [x] updated the **change log**
- [x] ran **pre-commit** checks for code style
- [x] attended to any **TO-DOs** left in the code

Related work items: #2269, #2334, #2336, #2475, #2476
Marek Wolan
2024-04-29 09:08:37 +00:00
31 changed files with 1268 additions and 29 deletions

View File

@@ -3,6 +3,7 @@ repos:
rev: v4.4.0
hooks:
- id: check-yaml
exclude: scenario_with_placeholders/
- id: end-of-file-fixer
- id: trailing-whitespace
- id: check-added-large-files

View File

@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Upgraded pydantic to version 2.7.0
- Upgraded Ray to version >= 2.9
- Added ipywidgets to the dependencies
- Added ability to define scenarios that change depending on the episode number.
- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config`
- Database Connection ID's are now created/issued by DatabaseService and not DatabaseClient
- added ability to set PrimAITE between development and production modes via PrimAITE CLI ``mode`` command
- Updated DatabaseClient so that it can now have a single native DatabaseClientConnection along with a collection of DatabaseClientConnection's.

View File

@@ -5,7 +5,7 @@
PrimAITE |VERSION| Configuration
********************************
PrimAITE uses a single configuration file to define everything needed to create the training environment for RL agents, including the network, the scripted agents, and the RL agent's action space, observation space, and reward function.
PrimAITE uses YAML configuration files to define everything needed to create the training environment for RL agents, including the network, the scripted agents, and the RL agent's action space, observation space, and reward function.
Example Configuration Hierarchy
###############################
@@ -34,3 +34,8 @@ Configurable items
configuration/game.rst
configuration/agents.rst
configuration/simulation.rst
Varying The Configuration Each Episode
######################################
PrimAITE allows for the configuration to be varied each episode. This is done by specifying a configuration folder instead of a single file. A full explanation is provided in the notebook `Using-Episode-Schedules.ipynb`. Please find the notebook in the user notebooks directory.

View File

@@ -0,0 +1,2 @@
# No green agents present
greens: &greens []

View File

@@ -0,0 +1,34 @@
agents: &greens
- ref: green_A
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.2
1: 0.8
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
reward_function:
reward_components:
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 1.0
options:
node_hostname: client

View File

@@ -0,0 +1,34 @@
agents: &greens
- ref: green_B
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.95
1: 0.05
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
reward_function:
reward_components:
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 1.0
options:
node_hostname: client

View File

@@ -0,0 +1,2 @@
# No red agents present
reds: &reds []

View File

@@ -0,0 +1,26 @@
reds: &reds
- ref: red_A
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 10
frequency: 10
variance: 0

View File

@@ -0,0 +1,26 @@
reds: &reds
- ref: red_B
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 3
frequency: 2
variance: 1

View File

@@ -0,0 +1,168 @@
io_settings:
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
game:
max_episode_length: 128
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- *greens
- *reds
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
routers: []
hosts:
- hostname: client
- hostname: server
num_services: 1
num_applications: 1
num_folders: 1
num_files: 1
num_nics: 1
include_num_access: false
include_nmne: true
- type: LINKS
label: LINKS
options:
link_references:
- client:eth-1<->switch_1:eth-1
- server:eth-1<->switch_1:eth-2
action_space:
action_list:
- type: DONOTHING
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_SHUTDOWN
options:
node_id: 0
2:
action: NODE_SHUTDOWN
options:
node_id: 1
3:
action: NODE_STARTUP
options:
node_id: 0
4:
action: NODE_STARTUP
options:
node_id: 1
5:
action: HOST_NIC_DISABLE
options:
node_id: 0
nic_id: 0
6:
action: HOST_NIC_DISABLE
options:
node_id: 1
nic_id: 0
7:
action: HOST_NIC_ENABLE
options:
node_id: 0
nic_id: 0
8:
action: HOST_NIC_ENABLE
options:
node_id: 1
nic_id: 0
options:
nodes:
- node_name: client
- node_name: server
max_folders_per_node: 0
max_files_per_folder: 0
max_services_per_node: 0
max_nics_per_node: 1
max_acl_rules: 0
ip_list:
- 192.168.1.2
- 192.168.1.3
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
agent_settings:
flatten_obs: false
simulation:
network:
nodes:
- hostname: client
type: computer
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.3
- type: DataManipulationBot
options:
server_ip: 192.168.1.3
payload: "DELETE"
- hostname: switch_1
type: switch
num_ports: 2
- hostname: server
type: server
ip_address: 192.168.1.3
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DatabaseService
links:
- endpoint_a_hostname: client
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: server
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2

View File

@@ -0,0 +1,14 @@
base_scenario: scenario.yaml
schedule:
  0:
    - greens_0.yaml
    - reds_0.yaml
  1:
    - greens_0.yaml
    - reds_1.yaml
  2:
    - greens_1.yaml
    - reds_1.yaml
  3:
    - greens_2.yaml
    - reds_2.yaml

View File

@@ -0,0 +1,372 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Using Episode Schedules\n",
"\n",
"PrimAITE supports the ability to use different variations on a scenario at different episodes. This can be used to increase \n",
"domain randomisation to prevent overfitting, or to set up curriculum learning to train agents to perform more complicated tasks.\n",
"\n",
"When using a fixed scenario, a single yaml config file is used. However, to use episode schedules, PrimAITE uses a \n",
"directory with several config files that work together."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Defining variations in the config file.\n",
"\n",
"### Base scenario\n",
"The base scenario is essentially the same as a fixed YAML configuration, but it can contain placeholders that are \n",
"populated with episode-specific data at runtime. The base scenario contains any network, agent, or settings that\n",
"remain fixed for the entire training/evaluation session.\n",
"\n",
"The placeholders are defined as YAML Aliases and they are denoted by an asterisk (`*placeholder`).\n",
"\n",
"### Variations\n",
"For each variation that could be used in a placeholder, there is a separate yaml file that contains the data that should populate the placeholder.\n",
"\n",
"The data that fills the placeholder is defined as a YAML Anchor in a separate file, denoted by an ampersand (`&anchor`).\n",
"\n",
"[Learn more about YAML Aliases and Anchors here.](https://www.educative.io/blog/advanced-yaml-syntax-cheatsheet#:~:text=YAML%20Anchors%20and%20Alias)\n",
"\n",
"### Schedule\n",
"Users must define which combination of scenario variations should be loaded in each episode. This takes the form of a\n",
"YAML file with a relative path to the base scenario and a list of paths to be loaded in during each episode.\n",
"\n",
"It takes the following format:\n",
"```yaml\n",
"base_scenario: base.yaml\n",
"schedule:\n",
"  0: # list of variations to load in at episode 0 (before the first call to env.reset() happens)\n",
"    - laydown_1.yaml\n",
"    - attack_1.yaml\n",
"  1: # list of variations to load in at episode 1 (after the first env.reset() call)\n",
"    - laydown_2.yaml\n",
"    - attack_2.yaml\n",
"```\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Demonstration"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Run `primaite setup` to copy the example config files into the correct directory. Then, import and define config location."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!primaite setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import yaml\n",
"from primaite.session.environment import PrimaiteGymEnv\n",
"from primaite import PRIMAITE_PATHS\n",
"from prettytable import PrettyTable\n",
"scenario_path = PRIMAITE_PATHS.user_config_path / \"example_config/scenario_with_placeholders\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Base Scenario File\n",
"Let's view the contents of the base scenario file:\n",
"\n",
"It contains all the base settings that stay fixed throughout all episodes, including the `io_settings`, `game` settings, the network layout and the blue agent definition. There are two placeholders: `*greens` and `*reds`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"scenario.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Schedule File\n",
"Let's view the contents of the schedule file:\n",
"\n",
"This file references the base scenario file and defines which variations should be loaded in at each episode. In this instance, there are four episodes: during the first episode `greens_0` and `reds_0` are used, during the second episode `greens_0` and `reds_1` are used, and so on."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"schedule.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Green Agent Variation Files\n",
"\n",
"There are three different variants of the green agent setup. In `greens_0`, there are no green agents, in `greens_1` there is a green agent that executes the database client application 80% of the time, and in `greens_2` there is a green agent that executes the database client application 5% of the time.\n",
"\n",
"(the difference between `greens_1` and `greens_2` is in the agent name and action probabilities)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"greens_0.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"greens_1.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"greens_2.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Red Agent Variation Files\n",
"\n",
"There are three different variants of the red agent setup. In `reds_0`, there are no red agents; in `reds_1` there is a red agent that executes every 10 steps (starting at step 10); and in `reds_2` there is a red agent that executes every 2 steps (starting at step 3, with a variance of 1)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"reds_0.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"reds_1.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"with open(scenario_path/\"reds_2.yaml\") as f:\n",
" print(f.read())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Running the simulation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Create the environment using the variable config."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env = PrimaiteGymEnv(env_config=scenario_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Episode 0\n",
"Let's run the episodes to verify that the agents are changing as expected. In episode 0, there should be no green or red agents, just the defender blue agent."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"Current episode number: {env.episode_counter}\")\n",
"print(f\"Agents present: {list(env.game.agents.keys())}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Episode 1\n",
"When we reset the environment, it moves on to episode 1, where it brings in `reds_1` for the red agent definition. The green placeholder stays empty (`greens_0`).\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.reset()\n",
"print(f\"Current episode number: {env.episode_counter}\")\n",
"print(f\"Agents present: {list(env.game.agents.keys())}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Episode 2\n",
"When we reset the environment again, it moves onto episode 2, where it will bring in greens_1 and reds_1 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
"\n",
"Most green actions will be `NODE_APPLICATION_EXECUTE` while red will `DONOTHING` except at steps 10 and 20."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.reset()\n",
"print(f\"Current episode number: {env.episode_counter}\")\n",
"print(f\"Agents present: {list(env.game.agents.keys())}\")\n",
"for i in range(21):\n",
" env.step(0)\n",
"\n",
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_A'].action_history[i].action\n",
" red_action = env.game.agents['red_A'].action_history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Episode 3\n",
"When we reset the environment again, it moves onto episode 3, where it will bring in greens_2 and reds_2 for green and red agent definitions. Let's verify the agent names and that they take actions at the defined frequency.\n",
"\n",
"Now, green will perform `NODE_APPLICATION_EXECUTE` only 5% of the time, while red will perform `NODE_APPLICATION_EXECUTE` more frequently than before."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.reset()\n",
"print(f\"Current episode number: {env.episode_counter}\")\n",
"print(f\"Agents present: {list(env.game.agents.keys())}\")\n",
"for i in range(21):\n",
" env.step(0)\n",
"\n",
"table = PrettyTable()\n",
"table.field_names = [\"step\", \"Green Action\", \"Red Action\"]\n",
"for i in range(21):\n",
" green_action = env.game.agents['green_B'].action_history[i].action\n",
" red_action = env.game.agents['red_B'].action_history[i].action\n",
" table.add_row([i, green_action, red_action])\n",
"print(table)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Further Episodes\n",
"\n",
"Since the schedule definition only goes up to episode 3, if we reset the environment again, we run out of episodes. The environment will simply loop back to the beginning, but it produces a warning message to make users aware that the episodes are being repeated."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env.reset(); # semicolon suppresses jupyter outputting the observation space.\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -1,6 +1,6 @@
import copy
import json
from typing import Any, Dict, Optional, SupportsFloat, Tuple
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -9,6 +9,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite import getLogger
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
@@ -23,14 +24,14 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, game_config: Dict):
def __init__(self, env_config: Union[Dict, str, PathLike]):
"""Initialise the environment."""
super().__init__()
self.game_config: Dict = game_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
@@ -91,9 +92,9 @@ class PrimaiteGymEnv(gymnasium.Env):
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state=state)
next_obs = self._get_obs()
@@ -138,8 +139,8 @@ class PrimaiteRayEnv(gymnasium.Env):
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(game_config=env_config)
self.env.episode_counter -= 1
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -155,6 +156,11 @@ class PrimaiteRayEnv(gymnasium.Env):
"""Close the simulation."""
self.env.close()
@property
def game(self) -> PrimaiteGame:
"""Pass through game from env."""
return self.env.game
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
@@ -166,16 +172,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
which is the PrimaiteGame instance.
:type env_config: Dict
"""
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
self.episode_counter: int = 0
"""Current episode number."""
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.episode_counter: int = 0
"""Current episode number."""
self.terminateds = set()
self.truncateds = set()
@@ -201,9 +207,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()

View File

@@ -0,0 +1,123 @@
import copy
from abc import ABC, abstractmethod
from itertools import chain
from pathlib import Path
from typing import Dict, List, Mapping, Sequence, Union

import pydantic
import yaml

from primaite import getLogger

_LOGGER = getLogger(__name__)


class EpisodeScheduler(pydantic.BaseModel, ABC):
    """
    Episode schedulers provide functionality to select different scenarios and game setups for each episode.

    This is useful when implementing advanced RL concepts like curriculum learning and domain randomisation.
    """

    @abstractmethod
    def __call__(self, episode_num: int) -> Dict:
        """Return the config that should be used during this episode."""
        ...


class ConstantEpisodeScheduler(EpisodeScheduler):
    """The constant episode schedule simply provides the same game setup every time."""

    config: Dict

    def __call__(self, episode_num: int) -> Dict:
        """Return the same config every time."""
        return copy.deepcopy(self.config)


class EpisodeListScheduler(EpisodeScheduler):
    """Cycle through a list of different game setups for each episode."""

    schedule: Mapping[int, List[str]]
    """Mapping from episode number to list of filenames"""
    episode_data: Mapping[str, str]
    """Mapping from filename to yaml string."""
    base_scenario: str
    """yaml string containing the base scenario."""

    _exceeded_episode_list: bool = False
    """
    Flag that's set to true when attempting to keep generating episodes after schedule runs out.

    When this happens, we loop back to the beginning, but a warning is raised.
    """

    def __call__(self, episode_num: int) -> Dict:
        """Return the config for the given episode number."""
        if episode_num >= len(self.schedule):
            if not self._exceeded_episode_list:
                self._exceeded_episode_list = True
                # Logger.warn is a deprecated alias; Logger.warning is the supported spelling
                _LOGGER.warning(
                    f"Running episode {episode_num} but the schedule only defines "
                    f"{len(self.schedule)} episodes. Looping back to the beginning"
                )
                # not sure if we should be using a traditional warning, or a _LOGGER.warning
            episode_num = episode_num % len(self.schedule)

        filenames_to_join = self.schedule[episode_num]
        yaml_data_to_join = [self.episode_data[fn] for fn in filenames_to_join] + [self.base_scenario]
        joined_yaml = "\n".join(yaml_data_to_join)
        parsed_cfg = yaml.safe_load(joined_yaml)

        # Unfortunately, using placeholders like this is slightly hacky, so we have to flatten the list of agents
        flat_agents_list = []
        for a in parsed_cfg["agents"]:
            if isinstance(a, Sequence):
                flat_agents_list.extend(a)
            else:
                flat_agents_list.append(a)
        parsed_cfg["agents"] = flat_agents_list

        return parsed_cfg


def build_scheduler(config: Union[str, Path, Dict]) -> EpisodeScheduler:
    """
    Convenience method to build an EpisodeScheduler with a dict, file path, or folder path.

    If a path to a folder is provided, it will be treated as a list of game scenarios.
    Otherwise, if a dict or a single file is provided, it will be treated as a constant game scenario.
    """
    # If we get a dict, return a constant episode schedule that repeats that one config forever
    if isinstance(config, Dict):
        return ConstantEpisodeScheduler(config=config)

    # Cast string to Path
    if isinstance(config, str):
        config = Path(config)

    if not config.exists():
        raise FileNotFoundError(f"Provided config path {config} could not be found.")

    if config.is_file():
        with open(config, "r") as f:
            cfg_data = yaml.safe_load(f)
        return ConstantEpisodeScheduler(config=cfg_data)

    if not config.is_dir():
        raise RuntimeError("Something went wrong while building Primaite config.")

    root = config
    schedule_path = root / "schedule.yaml"

    with open(schedule_path, "r") as f:
        schedule = yaml.safe_load(f)

    base_scenario_path = root / schedule["base_scenario"]
    files_to_load = set(chain.from_iterable(schedule["schedule"].values()))

    episode_data = {fp: (root / fp).read_text() for fp in files_to_load}

    return EpisodeListScheduler(
        schedule=schedule["schedule"], episode_data=episode_data, base_scenario=base_scenario_path.read_text()
    )

View File

@@ -0,0 +1,2 @@
# No green agents present
greens: &greens []

View File

@@ -0,0 +1,34 @@
agents: &greens
- ref: green_A
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.2
1: 0.8
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
reward_function:
reward_components:
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 1.0
options:
node_hostname: client

View File

@@ -0,0 +1,34 @@
agents: &greens
- ref: green_B
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.95
1: 0.05
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
reward_function:
reward_components:
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 1.0
options:
node_hostname: client

View File

@@ -0,0 +1,2 @@
# No red agents present
reds: &reds []

View File

@@ -0,0 +1,26 @@
reds: &reds
- ref: red_A
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 10
frequency: 10
variance: 0

View File

@@ -0,0 +1,26 @@
reds: &reds
- ref: red_B
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 3
frequency: 2
variance: 1

View File

@@ -0,0 +1,168 @@
io_settings:
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
game:
max_episode_length: 128
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- *greens
- *reds
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
routers: []
hosts:
- hostname: client
- hostname: server
num_services: 1
num_applications: 1
num_folders: 1
num_files: 1
num_nics: 1
include_num_access: false
include_nmne: true
- type: LINKS
label: LINKS
options:
link_references:
- client:eth-1<->switch_1:eth-1
- server:eth-1<->switch_1:eth-2
action_space:
action_list:
- type: DONOTHING
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_SHUTDOWN
options:
node_id: 0
2:
action: NODE_SHUTDOWN
options:
node_id: 1
3:
action: NODE_STARTUP
options:
node_id: 0
4:
action: NODE_STARTUP
options:
node_id: 1
5:
action: HOST_NIC_DISABLE
options:
node_id: 0
nic_id: 0
6:
action: HOST_NIC_DISABLE
options:
node_id: 1
nic_id: 0
7:
action: HOST_NIC_ENABLE
options:
node_id: 0
nic_id: 0
8:
action: HOST_NIC_ENABLE
options:
node_id: 1
nic_id: 0
options:
nodes:
- node_name: client
- node_name: server
max_folders_per_node: 0
max_files_per_folder: 0
max_services_per_node: 0
max_nics_per_node: 1
max_acl_rules: 0
ip_list:
- 192.168.1.2
- 192.168.1.3
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
agent_settings:
flatten_obs: false
simulation:
network:
nodes:
- hostname: client
type: computer
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.3
- type: DataManipulationBot
options:
server_ip: 192.168.1.3
payload: "DELETE"
- hostname: switch_1
type: switch
num_ports: 2
- hostname: server
type: server
ip_address: 192.168.1.3
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DatabaseService
links:
- endpoint_a_hostname: client
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: server
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2

View File

@@ -0,0 +1,14 @@
base_scenario: scenario.yaml
schedule:
  0:
    - greens_0.yaml
    - reds_0.yaml
  1:
    - greens_0.yaml
    - reds_1.yaml
  2:
    - greens_1.yaml
    - reds_1.yaml
  3:
    - greens_2.yaml
    - reds_2.yaml

View File

@@ -16,7 +16,7 @@ def test_sb3_compatibility():
     with open(data_manipulation_config_path(), "r") as f:
         cfg = yaml.safe_load(f)
-    gym = PrimaiteGymEnv(game_config=cfg)
+    gym = PrimaiteGymEnv(env_config=cfg)
     model = PPO("MlpPolicy", gym)
     model.learn(total_timesteps=1000)


@@ -21,7 +21,7 @@ class TestPrimaiteEnvironment:
         """Check that environment loads correctly from config and it can be reset."""
         with open(CFG_PATH, "r") as f:
             cfg = yaml.safe_load(f)
-        env = PrimaiteGymEnv(game_config=cfg)
+        env = PrimaiteGymEnv(env_config=cfg)
         def env_checks():
             assert env is not None
@@ -44,7 +44,7 @@ class TestPrimaiteEnvironment:
         """Make sure you can go all the way through the session without errors."""
         with open(CFG_PATH, "r") as f:
             cfg = yaml.safe_load(f)
-        env = PrimaiteGymEnv(game_config=cfg)
+        env = PrimaiteGymEnv(env_config=cfg)
         assert (num_actions := len(env.agent.action_manager.action_map)) == 54
         # run every action and make sure there's no crash
@@ -88,4 +88,4 @@ class TestPrimaiteEnvironment:
         with open(MISCONFIGURED_PATH, "r") as f:
             cfg = yaml.safe_load(f)
         with pytest.raises(pydantic.ValidationError):
-            env = PrimaiteGymEnv(game_config=cfg)
+            env = PrimaiteGymEnv(env_config=cfg)


@@ -44,7 +44,7 @@ def test_application_install_uninstall_on_uc2():
     with open(TEST_ASSETS_ROOT / "configs/test_application_install.yaml", "r") as f:
         cfg = yaml.safe_load(f)
-    env = PrimaiteGymEnv(game_config=cfg)
+    env = PrimaiteGymEnv(env_config=cfg)
     env.agent.flatten_obs = False
     env.reset()


@@ -0,0 +1,68 @@
import pytest
import yaml

from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from tests.conftest import TEST_ASSETS_ROOT

folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"
single_yaml_config = TEST_ASSETS_ROOT / "configs" / "test_primaite_session.yaml"
with open(single_yaml_config, "r") as f:
    config_dict = yaml.safe_load(f)


@pytest.mark.parametrize("env_type", [PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv])
def test_creating_env_with_folder(env_type):
    """Check that the environment can be created with a folder path."""

    def check_taking_steps(e):
        if isinstance(e, PrimaiteRayMARLEnv):
            for i in range(9):
                e.step({k: i for k in e.game.rl_agents})
        else:
            for i in range(9):
                e.step(i)

    env = env_type(env_config=folder_path)
    assert env is not None

    for _ in range(3):  # do it multiple times to ensure it loops back to the beginning
        assert len(env.game.agents) == 1
        assert "defender" in env.game.agents
        check_taking_steps(env)
        env.reset()

        assert len(env.game.agents) == 2
        assert "defender" in env.game.agents
        assert "red_A" in env.game.agents
        check_taking_steps(env)
        env.reset()

        assert len(env.game.agents) == 3
        assert all([a in env.game.agents for a in ["defender", "green_A", "red_A"]])
        check_taking_steps(env)
        env.reset()

        assert len(env.game.agents) == 3
        assert all([a in env.game.agents for a in ["defender", "green_B", "red_B"]])
        check_taking_steps(env)
        env.reset()


@pytest.mark.parametrize(
    "env_data, env_type",
    [
        (single_yaml_config, PrimaiteGymEnv),
        (single_yaml_config, PrimaiteRayEnv),
        (single_yaml_config, PrimaiteRayMARLEnv),
        (config_dict, PrimaiteGymEnv),
        (config_dict, PrimaiteRayEnv),
        (config_dict, PrimaiteRayMARLEnv),
    ],
)
def test_creating_env_with_static_config(env_data, env_type):
    """Check that the environment can be created from a single yaml file or an already-loaded config dict."""
    # Use the parametrized env_data so that both the file path and the dict cases are exercised.
    env = env_type(env_config=env_data)
    assert env is not None
    agents_before = len(env.game.agents)
    env.reset()
    assert len(env.game.agents) == agents_before


@@ -24,7 +24,7 @@ def test_io_settings():
     """Test that the io_settings are loaded correctly."""
     with open(BASIC_CONFIG, "r") as f:
         cfg = yaml.safe_load(f)
-    env = PrimaiteGymEnv(game_config=cfg)
+    env = PrimaiteGymEnv(env_config=cfg)
     assert env.io.settings is not None


@@ -507,7 +507,7 @@ def test_firewall_acl_add_remove_rule_integration():
     with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
         cfg = yaml.safe_load(f)
-    env = PrimaiteGymEnv(game_config=cfg)
+    env = PrimaiteGymEnv(env_config=cfg)
     # 1: Check that traffic is normal and acl starts off with 4 rules.
     firewall = env.game.simulation.network.get_node_by_hostname("firewall")
@@ -598,7 +598,7 @@ def test_firewall_port_disable_enable_integration():
     with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
         cfg = yaml.safe_load(f)
-    env = PrimaiteGymEnv(game_config=cfg)
+    env = PrimaiteGymEnv(env_config=cfg)
     firewall = env.game.simulation.network.get_node_by_hostname("firewall")
     assert firewall.dmz_port.enabled == True


@@ -103,7 +103,7 @@ def test_shared_reward():
     with open(CFG_PATH, "r") as f:
         cfg = yaml.safe_load(f)
-    env = PrimaiteGymEnv(game_config=cfg)
+    env = PrimaiteGymEnv(env_config=cfg)
     env.reset()


@@ -0,0 +1,50 @@
import pytest
import yaml

from primaite.session.episode_schedule import ConstantEpisodeScheduler, EpisodeListScheduler


def test_episode_list_scheduler():
    # Initialize an instance of EpisodeListScheduler

    # Define a schedule and episode data for testing
    schedule = {0: ["episode1"], 1: ["episode2"]}
    episode_data = {"episode1": "data1: 1", "episode2": "data2: 2"}
    base_scenario = """agents: []"""

    scheduler = EpisodeListScheduler(schedule=schedule, episode_data=episode_data, base_scenario=base_scenario)

    # Test when episode number is within the schedule
    result = scheduler(0)
    assert isinstance(result, dict)
    assert yaml.safe_load("data1: 1\nagents: []") == result

    # Test next episode
    result = scheduler(1)
    assert isinstance(result, dict)
    assert yaml.safe_load("data2: 2\nagents: []") == result

    # Test when episode number exceeds the schedule
    result = scheduler(2)
    assert isinstance(result, dict)
    assert yaml.safe_load("data1: 1\nagents: []") == result
    assert scheduler._exceeded_episode_list

    # Test when episode number is a sequence
    scheduler.schedule = {0: ["episode1", "episode2"]}
    result = scheduler(0)
    assert isinstance(result, dict)
    assert yaml.safe_load("data1: 1\ndata2: 2\nagents: []") == result


def test_constant_episode_scheduler():
    # Initialize an instance of ConstantEpisodeScheduler
    config = {"key": "value"}
    scheduler = ConstantEpisodeScheduler(config=config)

    result = scheduler(0)
    assert isinstance(result, dict)
    assert {"key": "value"} == result

    result = scheduler(1)
    assert isinstance(result, dict)
    assert {"key": "value"} == result
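
The behaviour these tests exercise can be sketched with a small stand-in class. This is not the `EpisodeListScheduler` implementation: the class name `SketchListScheduler` and the attribute `exceeded_episode_list` are hypothetical, and the ordering of the joined text (episode snippets first, base scenario last) is an assumption. The sketch also returns the combined YAML text rather than a parsed dict, to stay within the standard library.

```python
class SketchListScheduler:
    """Hypothetical stand-in that mimics the tested behaviour; not PrimAITE code."""

    def __init__(self, schedule, episode_data, base_scenario):
        self.schedule = schedule          # episode number -> list of snippet names
        self.episode_data = episode_data  # snippet name -> YAML text
        self.base_scenario = base_scenario
        self.exceeded_episode_list = False

    def __call__(self, episode_num):
        # Assumption: past the end of the schedule, loop back to the beginning.
        if episode_num >= len(self.schedule):
            self.exceeded_episode_list = True
            episode_num = episode_num % len(self.schedule)
        parts = [self.episode_data[name] for name in self.schedule[episode_num]]
        # Assumption: episode snippets are placed before the base scenario text.
        return "\n".join([*parts, self.base_scenario])


sketch = SketchListScheduler(
    schedule={0: ["episode1"], 1: ["episode2"]},
    episode_data={"episode1": "data1: 1", "episode2": "data2: 2"},
    base_scenario="agents: []",
)
print(sketch(0))
print(sketch(2))
```

Parsing the returned text with a YAML loader would give a dict comparable to the ones asserted in the tests above.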