#2476 Add test for episode scheduler
This commit is contained in:
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Upgraded pydantic to version 2.7.0
|
||||
- Upgraded Ray to version >= 2.9
|
||||
- Added ipywidgets to the dependencies
|
||||
- Added ability to define scenarios that change depending on the episode number.
|
||||
- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config`
|
||||
|
||||
## [Unreleased]
|
||||
- Made requests fail to reach their target if the node is off
|
||||
|
||||
@@ -227,7 +227,7 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = PrimaiteGymEnv(game_config=scenario_path)"
|
||||
"env = PrimaiteGymEnv(env_config=scenario_path)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import json
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
@@ -25,10 +24,10 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, game_config: Union[Dict, str, PathLike]):
|
||||
def __init__(self, env_config: Union[Dict, str, PathLike]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(game_config)
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
@@ -140,8 +139,8 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
:param env_config: A dictionary containing the environment configuration.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.env = PrimaiteGymEnv(game_config=env_config)
|
||||
self.env.episode_counter -= 1
|
||||
self.env = PrimaiteGymEnv(env_config=env_config)
|
||||
# self.env.episode_counter -= 1
|
||||
self.action_space = self.env.action_space
|
||||
self.observation_space = self.env.observation_space
|
||||
|
||||
@@ -157,6 +156,11 @@ class PrimaiteRayEnv(gymnasium.Env):
|
||||
"""Close the simulation."""
|
||||
self.env.close()
|
||||
|
||||
@property
|
||||
def game(self) -> PrimaiteGame:
|
||||
"""Pass through game from env."""
|
||||
return self.env.game
|
||||
|
||||
|
||||
class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
"""Ray Environment that inherits from MultiAgentEnv to allow training MARL systems."""
|
||||
@@ -168,16 +172,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
which is the PrimaiteGame instance.
|
||||
:type env_config: Dict
|
||||
"""
|
||||
self.game_config: Dict = env_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.io = PrimaiteIO.from_config(env_config.get("io_settings"))
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
"""Reference to the primaite game"""
|
||||
self._agent_ids = list(self.game.rl_agents.keys())
|
||||
"""Agent ids. This is a list of strings of agent names."""
|
||||
self.episode_counter: int = 0
|
||||
"""Current episode number."""
|
||||
|
||||
self.terminateds = set()
|
||||
self.truncateds = set()
|
||||
@@ -203,9 +207,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv):
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
@@ -22,6 +22,7 @@ class EpisodeScheduler(pydantic.BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def __call__(self, episode_num: int) -> Dict:
|
||||
"""Return the config that should be used during this episode."""
|
||||
...
|
||||
|
||||
|
||||
class ConstantEpisodeScheduler(EpisodeScheduler):
|
||||
|
||||
@@ -0,0 +1,2 @@
|
||||
# No green agents present
|
||||
greens: &greens []
|
||||
@@ -0,0 +1,34 @@
|
||||
agents: &greens
|
||||
- ref: green_A
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.2
|
||||
1: 0.8
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client
|
||||
@@ -0,0 +1,34 @@
|
||||
agents: &greens
|
||||
- ref: green_B
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
action_probabilities:
|
||||
0: 0.95
|
||||
1: 0.05
|
||||
observation_space: null
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DatabaseClient
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
node_id: 0
|
||||
application_id: 0
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY
|
||||
weight: 1.0
|
||||
options:
|
||||
node_hostname: client
|
||||
@@ -0,0 +1,2 @@
|
||||
# No red agents present
|
||||
reds: &reds []
|
||||
26
tests/assets/configs/scenario_with_placeholders/reds_1.yaml
Normal file
26
tests/assets/configs/scenario_with_placeholders/reds_1.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
reds: &reds
|
||||
- ref: red_A
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 10
|
||||
frequency: 10
|
||||
variance: 0
|
||||
26
tests/assets/configs/scenario_with_placeholders/reds_2.yaml
Normal file
26
tests/assets/configs/scenario_with_placeholders/reds_2.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
reds: &reds
|
||||
- ref: red_B
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
observation_space: null
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_APPLICATION_EXECUTE
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
applications:
|
||||
- application_name: DataManipulationBot
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DUMMY
|
||||
|
||||
agent_settings:
|
||||
start_settings:
|
||||
start_step: 3
|
||||
frequency: 2
|
||||
variance: 1
|
||||
168
tests/assets/configs/scenario_with_placeholders/scenario.yaml
Normal file
168
tests/assets/configs/scenario_with_placeholders/scenario.yaml
Normal file
@@ -0,0 +1,168 @@
|
||||
io_settings:
|
||||
save_agent_actions: true
|
||||
save_step_metadata: false
|
||||
save_pcap_logs: false
|
||||
save_sys_logs: false
|
||||
|
||||
|
||||
game:
|
||||
max_episode_length: 128
|
||||
ports:
|
||||
- HTTP
|
||||
- POSTGRES_SERVER
|
||||
protocols:
|
||||
- ICMP
|
||||
- TCP
|
||||
- UDP
|
||||
thresholds:
|
||||
nmne:
|
||||
high: 10
|
||||
medium: 5
|
||||
low: 0
|
||||
|
||||
agents:
|
||||
- *greens
|
||||
- *reds
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
type: ProxyAgent
|
||||
observation_space:
|
||||
type: CUSTOM
|
||||
options:
|
||||
components:
|
||||
- type: NODES
|
||||
label: NODES
|
||||
options:
|
||||
routers: []
|
||||
hosts:
|
||||
- hostname: client
|
||||
- hostname: server
|
||||
num_services: 1
|
||||
num_applications: 1
|
||||
num_folders: 1
|
||||
num_files: 1
|
||||
num_nics: 1
|
||||
include_num_access: false
|
||||
include_nmne: true
|
||||
|
||||
- type: LINKS
|
||||
label: LINKS
|
||||
options:
|
||||
link_references:
|
||||
- client:eth-1<->switch_1:eth-1
|
||||
- server:eth-1<->switch_1:eth-2
|
||||
|
||||
action_space:
|
||||
action_list:
|
||||
- type: DONOTHING
|
||||
- type: NODE_SHUTDOWN
|
||||
- type: NODE_STARTUP
|
||||
- type: HOST_NIC_ENABLE
|
||||
- type: HOST_NIC_DISABLE
|
||||
action_map:
|
||||
0:
|
||||
action: DONOTHING
|
||||
options: {}
|
||||
1:
|
||||
action: NODE_SHUTDOWN
|
||||
options:
|
||||
node_id: 0
|
||||
2:
|
||||
action: NODE_SHUTDOWN
|
||||
options:
|
||||
node_id: 1
|
||||
3:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 0
|
||||
4:
|
||||
action: NODE_STARTUP
|
||||
options:
|
||||
node_id: 1
|
||||
5:
|
||||
action: HOST_NIC_DISABLE
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
6:
|
||||
action: HOST_NIC_DISABLE
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
7:
|
||||
action: HOST_NIC_ENABLE
|
||||
options:
|
||||
node_id: 0
|
||||
nic_id: 0
|
||||
8:
|
||||
action: HOST_NIC_ENABLE
|
||||
options:
|
||||
node_id: 1
|
||||
nic_id: 0
|
||||
options:
|
||||
nodes:
|
||||
- node_name: client
|
||||
- node_name: server
|
||||
|
||||
max_folders_per_node: 0
|
||||
max_files_per_folder: 0
|
||||
max_services_per_node: 0
|
||||
max_nics_per_node: 1
|
||||
max_acl_rules: 0
|
||||
ip_list:
|
||||
- 192.168.1.2
|
||||
- 192.168.1.3
|
||||
|
||||
reward_function:
|
||||
reward_components:
|
||||
- type: DATABASE_FILE_INTEGRITY
|
||||
weight: 0.40
|
||||
options:
|
||||
node_hostname: database_server
|
||||
folder_name: database
|
||||
file_name: database.db
|
||||
|
||||
agent_settings:
|
||||
flatten_obs: false
|
||||
|
||||
|
||||
simulation:
|
||||
network:
|
||||
nodes:
|
||||
- hostname: client
|
||||
type: computer
|
||||
ip_address: 192.168.1.2
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
applications:
|
||||
- type: DatabaseClient
|
||||
options:
|
||||
db_server_ip: 192.168.1.3
|
||||
- type: DataManipulationBot
|
||||
options:
|
||||
server_ip: 192.168.1.3
|
||||
payload: "DELETE"
|
||||
|
||||
- hostname: switch_1
|
||||
type: switch
|
||||
num_ports: 2
|
||||
|
||||
- hostname: server
|
||||
type: server
|
||||
ip_address: 192.168.1.3
|
||||
subnet_mask: 255.255.255.0
|
||||
default_gateway: 192.168.1.1
|
||||
services:
|
||||
- type: DatabaseService
|
||||
|
||||
links:
|
||||
- endpoint_a_hostname: client
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 1
|
||||
|
||||
- endpoint_a_hostname: server
|
||||
endpoint_a_port: 1
|
||||
endpoint_b_hostname: switch_1
|
||||
endpoint_b_port: 2
|
||||
@@ -0,0 +1,14 @@
|
||||
base_scenario: scenario.yaml
|
||||
schedule:
|
||||
0:
|
||||
- greens_0.yaml
|
||||
- reds_0.yaml
|
||||
1:
|
||||
- greens_0.yaml
|
||||
- reds_1.yaml
|
||||
2:
|
||||
- greens_1.yaml
|
||||
- reds_1.yaml
|
||||
3:
|
||||
- greens_2.yaml
|
||||
- reds_2.yaml
|
||||
@@ -16,7 +16,7 @@ def test_sb3_compatibility():
|
||||
with open(data_manipulation_config_path(), "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
gym = PrimaiteGymEnv(game_config=cfg)
|
||||
gym = PrimaiteGymEnv(env_config=cfg)
|
||||
model = PPO("MlpPolicy", gym)
|
||||
|
||||
model.learn(total_timesteps=1000)
|
||||
|
||||
@@ -21,7 +21,7 @@ class TestPrimaiteEnvironment:
|
||||
"""Check that environment loads correctly from config and it can be reset."""
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
def env_checks():
|
||||
assert env is not None
|
||||
@@ -44,7 +44,7 @@ class TestPrimaiteEnvironment:
|
||||
"""Make sure you can go all the way through the session without errors."""
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
assert (num_actions := len(env.agent.action_manager.action_map)) == 54
|
||||
# run every action and make sure there's no crash
|
||||
@@ -88,4 +88,4 @@ class TestPrimaiteEnvironment:
|
||||
with open(MISCONFIGURED_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
with pytest.raises(pydantic.ValidationError):
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
@@ -44,7 +44,7 @@ def test_application_install_uninstall_on_uc2():
|
||||
with open(TEST_ASSETS_ROOT / "configs/test_application_install.yaml", "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
env.agent.flatten_obs = False
|
||||
env.reset()
|
||||
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
|
||||
from tests.conftest import TEST_ASSETS_ROOT
|
||||
|
||||
folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders"
|
||||
single_yaml_config = TEST_ASSETS_ROOT / "configs" / "test_primaite_session.yaml"
|
||||
with open(single_yaml_config, "r") as f:
|
||||
config_dict = yaml.safe_load(f)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("env_type", [PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv])
|
||||
def test_creating_env_with_folder(env_type):
|
||||
"""Check that the environment can be created with a folder path."""
|
||||
|
||||
def check_taking_steps(e):
|
||||
if isinstance(e, PrimaiteRayMARLEnv):
|
||||
for i in range(9):
|
||||
e.step({k: i for k in e.game.rl_agents})
|
||||
else:
|
||||
for i in range(9):
|
||||
e.step(i)
|
||||
|
||||
env = env_type(env_config=folder_path)
|
||||
assert env is not None
|
||||
for _ in range(3): # do it multiple times to ensure it loops back to the beginning
|
||||
assert len(env.game.agents) == 1
|
||||
assert "defender" in env.game.agents
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
assert len(env.game.agents) == 2
|
||||
assert "defender" in env.game.agents
|
||||
assert "red_A" in env.game.agents
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
assert len(env.game.agents) == 3
|
||||
assert all([a in env.game.agents for a in ["defender", "green_A", "red_A"]])
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
assert len(env.game.agents) == 3
|
||||
assert all([a in env.game.agents for a in ["defender", "green_B", "red_B"]])
|
||||
check_taking_steps(env)
|
||||
|
||||
env.reset()
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"env_data, env_type",
|
||||
[
|
||||
(single_yaml_config, PrimaiteGymEnv),
|
||||
(single_yaml_config, PrimaiteRayEnv),
|
||||
(single_yaml_config, PrimaiteRayMARLEnv),
|
||||
(config_dict, PrimaiteGymEnv),
|
||||
(config_dict, PrimaiteRayEnv),
|
||||
(config_dict, PrimaiteRayMARLEnv),
|
||||
],
|
||||
)
|
||||
def test_creating_env_with_static_config(env_data, env_type):
|
||||
"""Check that the environment can be created with a single yaml file."""
|
||||
env = env_type(env_config=single_yaml_config)
|
||||
assert env is not None
|
||||
agents_before = len(env.game.agents)
|
||||
env.reset()
|
||||
assert len(env.game.agents) == agents_before
|
||||
@@ -24,7 +24,7 @@ def test_io_settings():
|
||||
"""Test that the io_settings are loaded correctly."""
|
||||
with open(BASIC_CONFIG, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
assert env.io.settings is not None
|
||||
|
||||
|
||||
@@ -507,7 +507,7 @@ def test_firewall_acl_add_remove_rule_integration():
|
||||
with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
# 1: Check that traffic is normal and acl starts off with 4 rules.
|
||||
firewall = env.game.simulation.network.get_node_by_hostname("firewall")
|
||||
@@ -598,7 +598,7 @@ def test_firewall_port_disable_enable_integration():
|
||||
with open(FIREWALL_ACTIONS_NETWORK, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
firewall = env.game.simulation.network.get_node_by_hostname("firewall")
|
||||
|
||||
assert firewall.dmz_port.enabled == True
|
||||
|
||||
@@ -103,7 +103,7 @@ def test_shared_reward():
|
||||
with open(CFG_PATH, "r") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
|
||||
env = PrimaiteGymEnv(game_config=cfg)
|
||||
env = PrimaiteGymEnv(env_config=cfg)
|
||||
|
||||
env.reset()
|
||||
|
||||
|
||||
0
tests/unit_tests/_primaite/_session/__init__.py
Normal file
0
tests/unit_tests/_primaite/_session/__init__.py
Normal file
52
tests/unit_tests/_primaite/_session/test_episode_schedule.py
Normal file
52
tests/unit_tests/_primaite/_session/test_episode_schedule.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# FILEPATH: /home/cade/repos/PrimAITE/tests/unit_tests/_primaite/_session/test_episode_schedule.py
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
|
||||
from primaite.session.episode_schedule import ConstantEpisodeScheduler, EpisodeListScheduler
|
||||
|
||||
|
||||
def test_episode_list_scheduler():
|
||||
# Initialize an instance of EpisodeListScheduler
|
||||
|
||||
# Define a schedule and episode data for testing
|
||||
schedule = {0: ["episode1"], 1: ["episode2"]}
|
||||
episode_data = {"episode1": "data1: 1", "episode2": "data2: 2"}
|
||||
base_scenario = """agents: []"""
|
||||
|
||||
scheduler = EpisodeListScheduler(schedule=schedule, episode_data=episode_data, base_scenario=base_scenario)
|
||||
# Test when episode number is within the schedule
|
||||
result = scheduler(0)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data1: 1\nagents: []") == result
|
||||
|
||||
# Test next episode
|
||||
result = scheduler(1)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data2: 2\nagents: []") == result
|
||||
|
||||
# Test when episode number exceeds the schedule
|
||||
result = scheduler(2)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data1: 1\nagents: []") == result
|
||||
assert scheduler._exceeded_episode_list
|
||||
|
||||
# Test when episode number is a sequence
|
||||
scheduler.schedule = {0: ["episode1", "episode2"]}
|
||||
result = scheduler(0)
|
||||
assert isinstance(result, dict)
|
||||
assert yaml.safe_load("data1: 1\ndata2: 2\nagents: []") == result
|
||||
|
||||
|
||||
def test_constant_episode_scheduler():
|
||||
# Initialize an instance of ConstantEpisodeScheduler
|
||||
config = {"key": "value"}
|
||||
scheduler = ConstantEpisodeScheduler(config=config)
|
||||
|
||||
result = scheduler(0)
|
||||
assert isinstance(result, dict)
|
||||
assert {"key": "value"} == result
|
||||
|
||||
result = scheduler(1)
|
||||
assert isinstance(result, dict)
|
||||
assert {"key": "value"} == result
|
||||
Reference in New Issue
Block a user