#2476 Add test for episode scheduler

This commit is contained in:
Marek Wolan
2024-04-25 15:09:46 +01:00
parent 42ce264e73
commit 66f31e8ed1
21 changed files with 456 additions and 23 deletions

View File

@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Upgraded pydantic to version 2.7.0
- Upgraded Ray to version >= 2.9
- Added ipywidgets to the dependencies
- Added ability to define scenarios that change depending on the episode number.
- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config`
## [Unreleased]
- Made requests fail to reach their target if the node is off

View File

@@ -227,7 +227,7 @@
"metadata": {},
"outputs": [],
"source": [
"env = PrimaiteGymEnv(game_config=scenario_path)"
"env = PrimaiteGymEnv(env_config=scenario_path)"
]
},
{

View File

@@ -1,4 +1,3 @@
import copy
import json
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
@@ -25,10 +24,10 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, game_config: Union[Dict, str, PathLike]):
def __init__(self, env_config: Union[Dict, str, PathLike]):
"""Initialise the environment."""
super().__init__()
self.episode_scheduler: EpisodeScheduler = build_scheduler(game_config)
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
@@ -140,8 +139,8 @@ class PrimaiteRayEnv(gymnasium.Env):
:param env_config: A dictionary containing the environment configuration.
:type env_config: Dict
"""
self.env = PrimaiteGymEnv(game_config=env_config)
self.env.episode_counter -= 1
self.env = PrimaiteGymEnv(env_config=env_config)
# self.env.episode_counter -= 1
self.action_space = self.env.action_space
self.observation_space = self.env.observation_space
@@ -157,6 +156,11 @@ class PrimaiteRayEnv(gymnasium.Env):
"""Close the simulation."""
self.env.close()
@property
def game(self) -> PrimaiteGame:
"""Pass through game from env."""
return self.env.game
class PrimaiteRayMARLEnv(MultiAgentEnv):
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
@@ -168,16 +172,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
which is the PrimaiteGame instance.
:type env_config: Dict
"""
self.game_config: Dict = env_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
self.episode_counter: int = 0
"""Current episode number."""
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
"""Reference to the primaite game"""
self._agent_ids = list(self.game.rl_agents.keys())
"""Agent ids. This is a list of strings of agent names."""
self.episode_counter: int = 0
"""Current episode number."""
self.terminateds = set()
self.truncateds = set()
@@ -203,9 +207,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state)
next_obs = self._get_obs()

View File

@@ -22,6 +22,7 @@ class EpisodeScheduler(pydantic.BaseModel, ABC):
@abstractmethod
def __call__(self, episode_num: int) -> Dict:
"""Return the config that should be used during this episode."""
...
class ConstantEpisodeScheduler(EpisodeScheduler):

View File

@@ -0,0 +1,2 @@
# No green agents present
greens: &greens []

View File

@@ -0,0 +1,34 @@
agents: &greens
- ref: green_A
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.2
1: 0.8
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
reward_function:
reward_components:
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 1.0
options:
node_hostname: client

View File

@@ -0,0 +1,34 @@
agents: &greens
- ref: green_B
team: GREEN
type: ProbabilisticAgent
agent_settings:
action_probabilities:
0: 0.95
1: 0.05
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DatabaseClient
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_APPLICATION_EXECUTE
options:
node_id: 0
application_id: 0
reward_function:
reward_components:
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
weight: 1.0
options:
node_hostname: client

View File

@@ -0,0 +1,2 @@
# No red agents present
reds: &reds []

View File

@@ -0,0 +1,26 @@
reds: &reds
- ref: red_A
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 10
frequency: 10
variance: 0

View File

@@ -0,0 +1,26 @@
reds: &reds
- ref: red_B
team: RED
type: RedDatabaseCorruptingAgent
observation_space: null
action_space:
action_list:
- type: DONOTHING
- type: NODE_APPLICATION_EXECUTE
options:
nodes:
- node_name: client
applications:
- application_name: DataManipulationBot
reward_function:
reward_components:
- type: DUMMY
agent_settings:
start_settings:
start_step: 3
frequency: 2
variance: 1

View File

@@ -0,0 +1,168 @@
io_settings:
save_agent_actions: true
save_step_metadata: false
save_pcap_logs: false
save_sys_logs: false
game:
max_episode_length: 128
ports:
- HTTP
- POSTGRES_SERVER
protocols:
- ICMP
- TCP
- UDP
thresholds:
nmne:
high: 10
medium: 5
low: 0
agents:
- *greens
- *reds
- ref: defender
team: BLUE
type: ProxyAgent
observation_space:
type: CUSTOM
options:
components:
- type: NODES
label: NODES
options:
routers: []
hosts:
- hostname: client
- hostname: server
num_services: 1
num_applications: 1
num_folders: 1
num_files: 1
num_nics: 1
include_num_access: false
include_nmne: true
- type: LINKS
label: LINKS
options:
link_references:
- client:eth-1<->switch_1:eth-1
- server:eth-1<->switch_1:eth-2
action_space:
action_list:
- type: DONOTHING
- type: NODE_SHUTDOWN
- type: NODE_STARTUP
- type: HOST_NIC_ENABLE
- type: HOST_NIC_DISABLE
action_map:
0:
action: DONOTHING
options: {}
1:
action: NODE_SHUTDOWN
options:
node_id: 0
2:
action: NODE_SHUTDOWN
options:
node_id: 1
3:
action: NODE_STARTUP
options:
node_id: 0
4:
action: NODE_STARTUP
options:
node_id: 1
5:
action: HOST_NIC_DISABLE
options:
node_id: 0
nic_id: 0
6:
action: HOST_NIC_DISABLE
options:
node_id: 1
nic_id: 0
7:
action: HOST_NIC_ENABLE
options:
node_id: 0
nic_id: 0
8:
action: HOST_NIC_ENABLE
options:
node_id: 1
nic_id: 0
options:
nodes:
- node_name: client
- node_name: server
max_folders_per_node: 0
max_files_per_folder: 0
max_services_per_node: 0
max_nics_per_node: 1
max_acl_rules: 0
ip_list:
- 192.168.1.2
- 192.168.1.3
reward_function:
reward_components:
- type: DATABASE_FILE_INTEGRITY
weight: 0.40
options:
node_hostname: database_server
folder_name: database
file_name: database.db
agent_settings:
flatten_obs: false
simulation:
network:
nodes:
- hostname: client
type: computer
ip_address: 192.168.1.2
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
applications:
- type: DatabaseClient
options:
db_server_ip: 192.168.1.3
- type: DataManipulationBot
options:
server_ip: 192.168.1.3
payload: "DELETE"
- hostname: switch_1
type: switch
num_ports: 2
- hostname: server
type: server
ip_address: 192.168.1.3
subnet_mask: 255.255.255.0
default_gateway: 192.168.1.1
services:
- type: DatabaseService
links:
- endpoint_a_hostname: client
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 1
- endpoint_a_hostname: server
endpoint_a_port: 1
endpoint_b_hostname: switch_1
endpoint_b_port: 2

View File

@@ -0,0 +1,14 @@
base_scenario: scenario.yaml
schedule:
0:
- greens_0.yaml
- reds_0.yaml
1:
- greens_0.yaml
- reds_1.yaml
2:
- greens_1.yaml
- reds_1.yaml
3:
- greens_2.yaml
- reds_2.yaml

View File

@@ -16,7 +16,7 @@ def test_sb3_compatibility():
with open(data_manipulation_config_path(), "r") as f:
cfg = yaml.safe_load(f)
gym = PrimaiteGymEnv(game_config=cfg)
gym = PrimaiteGymEnv(env_config=cfg)
model = PPO("MlpPolicy", gym)
model.learn(total_timesteps=1000)

View File

@@ -21,7 +21,7 @@ class TestPrimaiteEnvironment:
"""Check that environment loads correctly from config and it can be reset."""
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)
def env_checks():
assert env is not None
@@ -44,7 +44,7 @@ class TestPrimaiteEnvironment:
"""Make sure you can go all the way through the session without errors."""
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)
assert (num_actions := len(env.agent.action_manager.action_map)) == 54
# run every action and make sure there's no crash
@@ -88,4 +88,4 @@ class TestPrimaiteEnvironment:
with open(MISCONFIGURED_PATH, "r") as f:
cfg = yaml.safe_load(f)
with pytest.raises(pydantic.ValidationError):
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)

View File

@@ -44,7 +44,7 @@ def test_application_install_uninstall_on_uc2():
with open(TEST_ASSETS_ROOT / "configs/test_application_install.yaml", "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)
env.agent.flatten_obs = False
env.reset()

View File

@@ -0,0 +1,68 @@
import pytest
import yaml
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from tests.conftest import TEST_ASSETS_ROOT
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"
single_yaml_config = TEST_ASSETS_ROOT / "configs" / "test_primaite_session.yaml"
with open(single_yaml_config, "r") as f:
config_dict = yaml.safe_load(f)
@pytest.mark.parametrize("env_type", [PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv])
def test_creating_env_with_folder(env_type):
"""Check that the environment can be created with a folder path."""
def check_taking_steps(e):
if isinstance(e, PrimaiteRayMARLEnv):
for i in range(9):
e.step({k: i for k in e.game.rl_agents})
else:
for i in range(9):
e.step(i)
env = env_type(env_config=folder_path)
assert env is not None
for _ in range(3): # do it multiple times to ensure it loops back to the beginning
assert len(env.game.agents) == 1
assert "defender" in env.game.agents
check_taking_steps(env)
env.reset()
assert len(env.game.agents) == 2
assert "defender" in env.game.agents
assert "red_A" in env.game.agents
check_taking_steps(env)
env.reset()
assert len(env.game.agents) == 3
assert all([a in env.game.agents for a in ["defender", "green_A", "red_A"]])
check_taking_steps(env)
env.reset()
assert len(env.game.agents) == 3
assert all([a in env.game.agents for a in ["defender", "green_B", "red_B"]])
check_taking_steps(env)
env.reset()
@pytest.mark.parametrize(
"env_data, env_type",
[
(single_yaml_config, PrimaiteGymEnv),
(single_yaml_config, PrimaiteRayEnv),
(single_yaml_config, PrimaiteRayMARLEnv),
(config_dict, PrimaiteGymEnv),
(config_dict, PrimaiteRayEnv),
(config_dict, PrimaiteRayMARLEnv),
],
)
def test_creating_env_with_static_config(env_data, env_type):
"""Check that the environment can be created with a single yaml file."""
env = env_type(env_config=single_yaml_config)
assert env is not None
agents_before = len(env.game.agents)
env.reset()
assert len(env.game.agents) == agents_before

View File

@@ -24,7 +24,7 @@ def test_io_settings():
"""Test that the io_settings are loaded correctly."""
with open(BASIC_CONFIG, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)
assert env.io.settings is not None

View File

@@ -507,7 +507,7 @@ def test_firewall_acl_add_remove_rule_integration():
with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)
# 1: Check that traffic is normal and acl starts off with 4 rules.
firewall = env.game.simulation.network.get_node_by_hostname("firewall")
@@ -598,7 +598,7 @@ def test_firewall_port_disable_enable_integration():
with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)
firewall = env.game.simulation.network.get_node_by_hostname("firewall")
assert firewall.dmz_port.enabled == True

View File

@@ -103,7 +103,7 @@ def test_shared_reward():
with open(CFG_PATH, "r") as f:
cfg = yaml.safe_load(f)
env = PrimaiteGymEnv(game_config=cfg)
env = PrimaiteGymEnv(env_config=cfg)
env.reset()

View File

@@ -0,0 +1,52 @@
# FILEPATH: /home/cade/repos/PrimAITE/tests/unit_tests/_primaite/_session/test_episode_schedule.py
import pytest
import yaml
from primaite.session.episode_schedule import ConstantEpisodeScheduler, EpisodeListScheduler
def test_episode_list_scheduler():
# Initialize an instance of EpisodeListScheduler
# Define a schedule and episode data for testing
schedule = {0: ["episode1"], 1: ["episode2"]}
episode_data = {"episode1": "data1: 1", "episode2": "data2: 2"}
base_scenario = """agents: []"""
scheduler = EpisodeListScheduler(schedule=schedule, episode_data=episode_data, base_scenario=base_scenario)
# Test when episode number is within the schedule
result = scheduler(0)
assert isinstance(result, dict)
assert yaml.safe_load("data1: 1\nagents: []") == result
# Test next episode
result = scheduler(1)
assert isinstance(result, dict)
assert yaml.safe_load("data2: 2\nagents: []") == result
# Test when episode number exceeds the schedule
result = scheduler(2)
assert isinstance(result, dict)
assert yaml.safe_load("data1: 1\nagents: []") == result
assert scheduler._exceeded_episode_list
# Test when episode number is a sequence
scheduler.schedule = {0: ["episode1", "episode2"]}
result = scheduler(0)
assert isinstance(result, dict)
assert yaml.safe_load("data1: 1\ndata2: 2\nagents: []") == result
def test_constant_episode_scheduler():
# Initialize an instance of ConstantEpisodeScheduler
config = {"key": "value"}
scheduler = ConstantEpisodeScheduler(config=config)
result = scheduler(0)
assert isinstance(result, dict)
assert {"key": "value"} == result
result = scheduler(1)
assert isinstance(result, dict)
assert {"key": "value"} == result