From 66f31e8ed1111fbfe71c2105d82d3ae22422a80d Mon Sep 17 00:00:00 2001 From: Marek Wolan Date: Thu, 25 Apr 2024 15:09:46 +0100 Subject: [PATCH] #2476 Add test for episode scheduler --- CHANGELOG.md | 2 + .../notebooks/Using-Episode-Schedules.ipynb | 2 +- src/primaite/session/environment.py | 30 ++-- src/primaite/session/episode_schedule.py | 1 + .../scenario_with_placeholders/greens_0.yaml | 2 + .../scenario_with_placeholders/greens_1.yaml | 34 ++++ .../scenario_with_placeholders/greens_2.yaml | 34 ++++ .../scenario_with_placeholders/reds_0.yaml | 2 + .../scenario_with_placeholders/reds_1.yaml | 26 +++ .../scenario_with_placeholders/reds_2.yaml | 26 +++ .../scenario_with_placeholders/scenario.yaml | 168 ++++++++++++++++++ .../scenario_with_placeholders/schedule.yaml | 14 ++ .../environments/test_sb3_environment.py | 2 +- .../e2e_integration_tests/test_environment.py | 6 +- .../test_uc2_data_manipulation_scenario.py | 2 +- .../test_episode_scheduler.py | 68 +++++++ .../test_io_settings.py | 2 +- .../game_layer/test_actions.py | 4 +- .../game_layer/test_rewards.py | 2 +- .../unit_tests/_primaite/_session/__init__.py | 0 .../_session/test_episode_schedule.py | 52 ++++++ 21 files changed, 456 insertions(+), 23 deletions(-) create mode 100644 tests/assets/configs/scenario_with_placeholders/greens_0.yaml create mode 100644 tests/assets/configs/scenario_with_placeholders/greens_1.yaml create mode 100644 tests/assets/configs/scenario_with_placeholders/greens_2.yaml create mode 100644 tests/assets/configs/scenario_with_placeholders/reds_0.yaml create mode 100644 tests/assets/configs/scenario_with_placeholders/reds_1.yaml create mode 100644 tests/assets/configs/scenario_with_placeholders/reds_2.yaml create mode 100644 tests/assets/configs/scenario_with_placeholders/scenario.yaml create mode 100644 tests/assets/configs/scenario_with_placeholders/schedule.yaml create mode 100644 tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py create mode 100644 tests/unit_tests/_primaite/_session/__init__.py create mode 100644 tests/unit_tests/_primaite/_session/test_episode_schedule.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 81fe5621..4147d6f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Upgraded pydantic to version 2.7.0 - Upgraded Ray to version >= 2.9 - Added ipywidgets to the dependencies +- Added ability to define scenarios that change depending on the episode number. +- Standardised Environment API by renaming the config parameter of `PrimaiteGymEnv` from `game_config` to `env_config` ## [Unreleased] - Made requests fail to reach their target if the node is off diff --git a/src/primaite/notebooks/Using-Episode-Schedules.ipynb b/src/primaite/notebooks/Using-Episode-Schedules.ipynb index 80e67065..c616a410 100644 --- a/src/primaite/notebooks/Using-Episode-Schedules.ipynb +++ b/src/primaite/notebooks/Using-Episode-Schedules.ipynb @@ -227,7 +227,7 @@ "metadata": {}, "outputs": [], "source": [ - "env = PrimaiteGymEnv(game_config=scenario_path)" + "env = PrimaiteGymEnv(env_config=scenario_path)" ] }, { diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index dea6b1dc..abbf051b 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,4 +1,3 @@ -import copy import json from os import PathLike from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union @@ -25,10 +24,10 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. """ - def __init__(self, game_config: Union[Dict, str, PathLike]): + def __init__(self, env_config: Union[Dict, str, PathLike]): """Initialise the environment.""" super().__init__() - self.episode_scheduler: EpisodeScheduler = build_scheduler(game_config) + self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) """Object that returns a config corresponding to the current episode.""" self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" @@ -140,8 +139,8 @@ class PrimaiteRayEnv(gymnasium.Env): :param env_config: A dictionary containing the environment configuration. :type env_config: Dict """ - self.env = PrimaiteGymEnv(game_config=env_config) - self.env.episode_counter -= 1 + self.env = PrimaiteGymEnv(env_config=env_config) + # self.env.episode_counter -= 1 self.action_space = self.env.action_space self.observation_space = self.env.observation_space @@ -157,6 +156,11 @@ class PrimaiteRayEnv(gymnasium.Env): """Close the simulation.""" self.env.close() + @property + def game(self) -> PrimaiteGame: + """Pass through game from env.""" + return self.env.game + class PrimaiteRayMARLEnv(MultiAgentEnv): """Ray Environment that inherits from MultiAgentEnv to allow training MARL systems.""" @@ -168,16 +172,16 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): which is the PrimaiteGame instance. :type env_config: Dict """ - self.game_config: Dict = env_config - """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" - self.io = PrimaiteIO.from_config(env_config.get("io_settings")) + self.episode_counter: int = 0 + """Current episode number.""" + self.episode_scheduler: EpisodeScheduler = build_scheduler(env_config) + """Object that returns a config corresponding to the current episode.""" + self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) + self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) """Reference to the primaite game""" self._agent_ids = list(self.game.rl_agents.keys()) """Agent ids. This is a list of strings of agent names.""" - self.episode_counter: int = 0 - """Current episode number.""" self.terminateds = set() self.truncateds = set() @@ -203,9 +207,9 @@ class PrimaiteRayMARLEnv(MultiAgentEnv): if self.io.settings.save_agent_actions: all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) - self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) - self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 + self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(self.episode_counter)) + self.game.setup_for_episode(episode=self.episode_counter) state = self.game.get_sim_state() self.game.update_agents(state) next_obs = self._get_obs() diff --git a/src/primaite/session/episode_schedule.py b/src/primaite/session/episode_schedule.py index 69ae5778..fa010d27 100644 --- a/src/primaite/session/episode_schedule.py +++ b/src/primaite/session/episode_schedule.py @@ -22,6 +22,7 @@ class EpisodeScheduler(pydantic.BaseModel, ABC): @abstractmethod def __call__(self, episode_num: int) -> Dict: """Return the config that should be used during this episode.""" + ... class ConstantEpisodeScheduler(EpisodeScheduler): diff --git a/tests/assets/configs/scenario_with_placeholders/greens_0.yaml b/tests/assets/configs/scenario_with_placeholders/greens_0.yaml new file mode 100644 index 00000000..f31c52fa --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/greens_0.yaml @@ -0,0 +1,2 @@ +# No green agents present +greens: &greens [] diff --git a/tests/assets/configs/scenario_with_placeholders/greens_1.yaml b/tests/assets/configs/scenario_with_placeholders/greens_1.yaml new file mode 100644 index 00000000..98d2392a --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/greens_1.yaml @@ -0,0 +1,34 @@ +agents: &greens + - ref: green_A + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.2 + 1: 0.8 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client + applications: + - application_name: DatabaseClient + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + + reward_function: + reward_components: + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 1.0 + options: + node_hostname: client diff --git a/tests/assets/configs/scenario_with_placeholders/greens_2.yaml b/tests/assets/configs/scenario_with_placeholders/greens_2.yaml new file mode 100644 index 00000000..17a5977b --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/greens_2.yaml @@ -0,0 +1,34 @@ +agents: &greens + - ref: green_B + team: GREEN + type: ProbabilisticAgent + agent_settings: + action_probabilities: + 0: 0.95 + 1: 0.05 + observation_space: null + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client + applications: + - application_name: DatabaseClient + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_APPLICATION_EXECUTE + options: + node_id: 0 + application_id: 0 + + reward_function: + reward_components: + - type: GREEN_ADMIN_DATABASE_UNREACHABLE_PENALTY + weight: 1.0 + options: + node_hostname: client diff --git a/tests/assets/configs/scenario_with_placeholders/reds_0.yaml b/tests/assets/configs/scenario_with_placeholders/reds_0.yaml new file mode 100644 index 00000000..878aba97 --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/reds_0.yaml @@ -0,0 +1,2 @@ +# No red agents present +reds: &reds [] diff --git a/tests/assets/configs/scenario_with_placeholders/reds_1.yaml b/tests/assets/configs/scenario_with_placeholders/reds_1.yaml new file mode 100644 index 00000000..31675a0b --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/reds_1.yaml @@ -0,0 +1,26 @@ +reds: &reds + - ref: red_A + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client + applications: + - application_name: DataManipulationBot + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 10 + frequency: 10 + variance: 0 diff --git a/tests/assets/configs/scenario_with_placeholders/reds_2.yaml b/tests/assets/configs/scenario_with_placeholders/reds_2.yaml new file mode 100644 index 00000000..c5572b89 --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/reds_2.yaml @@ -0,0 +1,26 @@ +reds: &reds + - ref: red_B + team: RED + type: RedDatabaseCorruptingAgent + + observation_space: null + + action_space: + action_list: + - type: DONOTHING + - type: NODE_APPLICATION_EXECUTE + options: + nodes: + - node_name: client + applications: + - application_name: DataManipulationBot + + reward_function: + reward_components: + - type: DUMMY + + agent_settings: + start_settings: + start_step: 3 + frequency: 2 + variance: 1 diff --git a/tests/assets/configs/scenario_with_placeholders/scenario.yaml b/tests/assets/configs/scenario_with_placeholders/scenario.yaml new file mode 100644 index 00000000..81848b2d --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/scenario.yaml @@ -0,0 +1,168 @@ +io_settings: + save_agent_actions: true + save_step_metadata: false + save_pcap_logs: false + save_sys_logs: false + + +game: + max_episode_length: 128 + ports: + - HTTP + - POSTGRES_SERVER + protocols: + - ICMP + - TCP + - UDP + thresholds: + nmne: + high: 10 + medium: 5 + low: 0 + +agents: + - *greens + - *reds + + - ref: defender + team: BLUE + type: ProxyAgent + observation_space: + type: CUSTOM + options: + components: + - type: NODES + label: NODES + options: + routers: [] + hosts: + - hostname: client + - hostname: server + num_services: 1 + num_applications: 1 + num_folders: 1 + num_files: 1 + num_nics: 1 + include_num_access: false + include_nmne: true + + - type: LINKS + label: LINKS + options: + link_references: + - client:eth-1<->switch_1:eth-1 + - server:eth-1<->switch_1:eth-2 + + action_space: + action_list: + - type: DONOTHING + - type: NODE_SHUTDOWN + - type: NODE_STARTUP + - type: HOST_NIC_ENABLE + - type: HOST_NIC_DISABLE + action_map: + 0: + action: DONOTHING + options: {} + 1: + action: NODE_SHUTDOWN + options: + node_id: 0 + 2: + action: NODE_SHUTDOWN + options: + node_id: 1 + 3: + action: NODE_STARTUP + options: + node_id: 0 + 4: + action: NODE_STARTUP + options: + node_id: 1 + 5: + action: HOST_NIC_DISABLE + options: + node_id: 0 + nic_id: 0 + 6: + action: HOST_NIC_DISABLE + options: + node_id: 1 + nic_id: 0 + 7: + action: HOST_NIC_ENABLE + options: + node_id: 0 + nic_id: 0 + 8: + action: HOST_NIC_ENABLE + options: + node_id: 1 + nic_id: 0 + options: + nodes: + - node_name: client + - node_name: server + + max_folders_per_node: 0 + max_files_per_folder: 0 + max_services_per_node: 0 + max_nics_per_node: 1 + max_acl_rules: 0 + ip_list: + - 192.168.1.2 + - 192.168.1.3 + + reward_function: + reward_components: + - type: DATABASE_FILE_INTEGRITY + weight: 0.40 + options: + node_hostname: database_server + folder_name: database + file_name: database.db + + agent_settings: + flatten_obs: false + + +simulation: + network: + nodes: + - hostname: client + type: computer + ip_address: 192.168.1.2 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + applications: + - type: DatabaseClient + options: + db_server_ip: 192.168.1.3 + - type: DataManipulationBot + options: + server_ip: 192.168.1.3 + payload: "DELETE" + + - hostname: switch_1 + type: switch + num_ports: 2 + + - hostname: server + type: server + ip_address: 192.168.1.3 + subnet_mask: 255.255.255.0 + default_gateway: 192.168.1.1 + services: + - type: DatabaseService + + links: + - endpoint_a_hostname: client + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 1 + + - endpoint_a_hostname: server + endpoint_a_port: 1 + endpoint_b_hostname: switch_1 + endpoint_b_port: 2 diff --git a/tests/assets/configs/scenario_with_placeholders/schedule.yaml b/tests/assets/configs/scenario_with_placeholders/schedule.yaml new file mode 100644 index 00000000..07ee4e50 --- /dev/null +++ b/tests/assets/configs/scenario_with_placeholders/schedule.yaml @@ -0,0 +1,14 @@ +base_scenario: scenario.yaml +schedule: + 0: + - greens_0.yaml + - reds_0.yaml + 1: + - greens_0.yaml + - reds_1.yaml + 2: + - greens_1.yaml + - reds_1.yaml + 3: + - greens_2.yaml + - reds_2.yaml diff --git a/tests/e2e_integration_tests/environments/test_sb3_environment.py b/tests/e2e_integration_tests/environments/test_sb3_environment.py index 83965191..f6ff595f 100644 --- a/tests/e2e_integration_tests/environments/test_sb3_environment.py +++ b/tests/e2e_integration_tests/environments/test_sb3_environment.py @@ -16,7 +16,7 @@ def test_sb3_compatibility(): with open(data_manipulation_config_path(), "r") as f: cfg = yaml.safe_load(f) - gym = PrimaiteGymEnv(game_config=cfg) + gym = PrimaiteGymEnv(env_config=cfg) model = PPO("MlpPolicy", gym) model.learn(total_timesteps=1000) diff --git a/tests/e2e_integration_tests/test_environment.py b/tests/e2e_integration_tests/test_environment.py index 673e1dc4..accfad50 100644 --- a/tests/e2e_integration_tests/test_environment.py +++ b/tests/e2e_integration_tests/test_environment.py @@ -21,7 +21,7 @@ class TestPrimaiteEnvironment: """Check that environment loads correctly from config and it can be reset.""" with open(CFG_PATH, "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) def env_checks(): assert env is not None @@ -44,7 +44,7 @@ class TestPrimaiteEnvironment: """Make sure you can go all the way through the session without errors.""" with open(CFG_PATH, "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) assert (num_actions := len(env.agent.action_manager.action_map)) == 54 # run every action and make sure there's no crash @@ -88,4 +88,4 @@ class TestPrimaiteEnvironment: with open(MISCONFIGURED_PATH, "r") as f: cfg = yaml.safe_load(f) with pytest.raises(pydantic.ValidationError): - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) diff --git a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py index 0b31a353..db79e504 100644 --- a/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py +++ b/tests/e2e_integration_tests/test_uc2_data_manipulation_scenario.py @@ -44,7 +44,7 @@ def test_application_install_uninstall_on_uc2(): with open(TEST_ASSETS_ROOT / "configs/test_application_install.yaml", "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) env.agent.flatten_obs = False env.reset() diff --git a/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py new file mode 100644 index 00000000..6b40fb1a --- /dev/null +++ b/tests/integration_tests/configuration_file_parsing/test_episode_scheduler.py @@ -0,0 +1,68 @@ +import pytest +import yaml + +from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv +from tests.conftest import TEST_ASSETS_ROOT + +folder_path = TEST_ASSETS_ROOT / "configs" / "scenario_with_placeholders" +single_yaml_config = TEST_ASSETS_ROOT / "configs" / "test_primaite_session.yaml" +with open(single_yaml_config, "r") as f: + config_dict = yaml.safe_load(f) + + +@pytest.mark.parametrize("env_type", [PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]) +def test_creating_env_with_folder(env_type): + """Check that the environment can be created with a folder path.""" + + def check_taking_steps(e): + if isinstance(e, PrimaiteRayMARLEnv): + for i in range(9): + e.step({k: i for k in e.game.rl_agents}) + else: + for i in range(9): + e.step(i) + + env = env_type(env_config=folder_path) + assert env is not None + for _ in range(3): # do it multiple times to ensure it loops back to the beginning + assert len(env.game.agents) == 1 + assert "defender" in env.game.agents + check_taking_steps(env) + + env.reset() + assert len(env.game.agents) == 2 + assert "defender" in env.game.agents + assert "red_A" in env.game.agents + check_taking_steps(env) + + env.reset() + assert len(env.game.agents) == 3 + assert all([a in env.game.agents for a in ["defender", "green_A", "red_A"]]) + check_taking_steps(env) + + env.reset() + assert len(env.game.agents) == 3 + assert all([a in env.game.agents for a in ["defender", "green_B", "red_B"]]) + check_taking_steps(env) + + env.reset() + + +@pytest.mark.parametrize( + "env_data, env_type", + [ + (single_yaml_config, PrimaiteGymEnv), + (single_yaml_config, PrimaiteRayEnv), + (single_yaml_config, PrimaiteRayMARLEnv), + (config_dict, PrimaiteGymEnv), + (config_dict, PrimaiteRayEnv), + (config_dict, PrimaiteRayMARLEnv), + ], +) +def test_creating_env_with_static_config(env_data, env_type): + """Check that the environment can be created with a single yaml file.""" + env = env_type(env_config=single_yaml_config) + assert env is not None + agents_before = len(env.game.agents) + env.reset() + assert len(env.game.agents) == agents_before diff --git a/tests/integration_tests/configuration_file_parsing/test_io_settings.py b/tests/integration_tests/configuration_file_parsing/test_io_settings.py index e66350cf..21f56e97 100644 --- a/tests/integration_tests/configuration_file_parsing/test_io_settings.py +++ b/tests/integration_tests/configuration_file_parsing/test_io_settings.py @@ -24,7 +24,7 @@ def test_io_settings(): """Test that the io_settings are loaded correctly.""" with open(BASIC_CONFIG, "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) assert env.io.settings is not None diff --git a/tests/integration_tests/game_layer/test_actions.py b/tests/integration_tests/game_layer/test_actions.py index 855bc38d..edaf5d8d 100644 --- a/tests/integration_tests/game_layer/test_actions.py +++ b/tests/integration_tests/game_layer/test_actions.py @@ -507,7 +507,7 @@ def test_firewall_acl_add_remove_rule_integration(): with open(FIREWALL_ACTIONS_NETWORK, "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) # 1: Check that traffic is normal and acl starts off with 4 rules. firewall = env.game.simulation.network.get_node_by_hostname("firewall") @@ -598,7 +598,7 @@ def test_firewall_port_disable_enable_integration(): with open(FIREWALL_ACTIONS_NETWORK, "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) firewall = env.game.simulation.network.get_node_by_hostname("firewall") assert firewall.dmz_port.enabled == True diff --git a/tests/integration_tests/game_layer/test_rewards.py b/tests/integration_tests/game_layer/test_rewards.py index cfd013bc..7c38057e 100644 --- a/tests/integration_tests/game_layer/test_rewards.py +++ b/tests/integration_tests/game_layer/test_rewards.py @@ -103,7 +103,7 @@ def test_shared_reward(): with open(CFG_PATH, "r") as f: cfg = yaml.safe_load(f) - env = PrimaiteGymEnv(game_config=cfg) + env = PrimaiteGymEnv(env_config=cfg) env.reset() diff --git a/tests/unit_tests/_primaite/_session/__init__.py b/tests/unit_tests/_primaite/_session/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_tests/_primaite/_session/test_episode_schedule.py b/tests/unit_tests/_primaite/_session/test_episode_schedule.py new file mode 100644 index 00000000..5d28f24e --- /dev/null +++ b/tests/unit_tests/_primaite/_session/test_episode_schedule.py @@ -0,0 +1,52 @@ +# FILEPATH: /home/cade/repos/PrimAITE/tests/unit_tests/_primaite/_session/test_episode_schedule.py + +import pytest +import yaml + +from primaite.session.episode_schedule import ConstantEpisodeScheduler, EpisodeListScheduler + + +def test_episode_list_scheduler(): + # Initialize an instance of EpisodeListScheduler + + # Define a schedule and episode data for testing + schedule = {0: ["episode1"], 1: ["episode2"]} + episode_data = {"episode1": "data1: 1", "episode2": "data2: 2"} + base_scenario = """agents: []""" + + scheduler = EpisodeListScheduler(schedule=schedule, episode_data=episode_data, base_scenario=base_scenario) + # Test when episode number is within the schedule + result = scheduler(0) + assert isinstance(result, dict) + assert yaml.safe_load("data1: 1\nagents: []") == result + + # Test next episode + result = scheduler(1) + assert isinstance(result, dict) + assert yaml.safe_load("data2: 2\nagents: []") == result + + # Test when episode number exceeds the schedule + result = scheduler(2) + assert isinstance(result, dict) + assert yaml.safe_load("data1: 1\nagents: []") == result + assert scheduler._exceeded_episode_list + + # Test when episode number is a sequence + scheduler.schedule = {0: ["episode1", "episode2"]} + result = scheduler(0) + assert isinstance(result, dict) + assert yaml.safe_load("data1: 1\ndata2: 2\nagents: []") == result + + +def test_constant_episode_scheduler(): + # Initialize an instance of ConstantEpisodeScheduler + config = {"key": "value"} + scheduler = ConstantEpisodeScheduler(config=config) + + result = scheduler(0) + assert isinstance(result, dict) + assert {"key": "value"} == result + + result = scheduler(1) + assert isinstance(result, dict) + assert {"key": "value"} == result