diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml index 2702cbe6..e152f23f 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/greens_1.yaml @@ -1,5 +1,5 @@ greens: &greens - - ref: green_client_2 + - ref: green_A team: GREEN type: ProbabilisticAgent agent_settings: @@ -48,7 +48,7 @@ greens: &greens options: node_hostname: client_2 - - ref: green_client_1 + - ref: green_B team: GREEN type: ProbabilisticAgent agent_settings: diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml index e0c33656..87c8ffe3 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/greens_2.yaml @@ -1,5 +1,5 @@ greens: &greens - - ref: green_client_2 + - ref: green_C team: GREEN type: ProbabilisticAgent agent_settings: diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml index f41fca8d..9019f6c6 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/reds_1.yaml @@ -1,5 +1,5 @@ reds: &reds - - ref: attacker_1 + - ref: red_A team: RED type: RedDatabaseCorruptingAgent diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml index 13e1dd3b..c3304e17 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/reds_2.yaml @@ -1,5 +1,5 @@ reds: &reds - - ref: attacker_2 + - ref: 
red_B team: RED type: RedDatabaseCorruptingAgent diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml index 426b79c7..b3d47f78 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/scenario.yaml @@ -23,7 +23,6 @@ game: agents: - *greens - *reds - - *blue(s) - ref: defender team: BLUE diff --git a/src/primaite/config/_package_data/scenario_with_placeholders/schedule.yaml b/src/primaite/config/_package_data/scenario_with_placeholders/schedule.yaml index 866c9895..2d26eb31 100644 --- a/src/primaite/config/_package_data/scenario_with_placeholders/schedule.yaml +++ b/src/primaite/config/_package_data/scenario_with_placeholders/schedule.yaml @@ -1,17 +1,17 @@ base_scenario: scenario.yaml schedule: 0: - green: greens_1.yaml - red: reds_1.yaml + - greens_1.yaml + - reds_1.yaml 1: - green: greens_1.yaml - red: reds_2.yaml + - greens_1.yaml + - reds_2.yaml 2: - green: greens_2.yaml - red: reds_1.yaml + - greens_2.yaml + - reds_1.yaml 3: - green: greens_2.yaml - red: reds_2.yaml + - greens_2.yaml + - reds_2.yaml # touch base with container to see what they've implemented for training schedule and evaluation schedule - for naming convention consistency # when you exceed the number of episodes defined in the yaml, raise a warning and loop back to the beginning diff --git a/src/primaite/notebooks/Scenario-Placeholders.ipynb b/src/primaite/notebooks/Scenario-Placeholders.ipynb index 67835999..9de34a81 100644 --- a/src/primaite/notebooks/Scenario-Placeholders.ipynb +++ b/src/primaite/notebooks/Scenario-Placeholders.ipynb @@ -110,6 +110,54 @@ "print(list(gym.game.agents.keys()))" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from primaite.session.environment import PrimaiteGymEnv" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env_2 = PrimaiteGymEnv(game_config='/home/cade/repos/PrimAITE/src/primaite/config/_package_data/scenario_with_placeholders')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for i in range(10):\n", + " print(env_2.episode_counter)\n", + " print(list(env_2.game.agents.keys()))\n", + " env_2.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "env = PrimaiteGymEnv(game_config='/home/cade/repos/PrimAITE/src/primaite/config/_package_data/data_manipulation.yaml')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sum([[1,2],[3,4]])" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/src/primaite/session/environment.py b/src/primaite/session/environment.py index 9311e1f7..dea6b1dc 100644 --- a/src/primaite/session/environment.py +++ b/src/primaite/session/environment.py @@ -1,6 +1,7 @@ import copy import json -from typing import Any, Dict, Optional, SupportsFloat, Tuple +from os import PathLike +from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union import gymnasium from gymnasium.core import ActType, ObsType @@ -9,6 +10,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv from primaite import getLogger from primaite.game.agent.interface import ProxyAgent from primaite.game.game import PrimaiteGame +from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler from primaite.session.io import PrimaiteIO from primaite.simulator import SIM_OUTPUT @@ -23,17 +25,14 @@ class PrimaiteGymEnv(gymnasium.Env): assumptions about the agent list always having a list of length 1. 
""" - def __init__(self, game_config: Dict): + def __init__(self, game_config: Union[Dict, str, PathLike]): """Initialise the environment.""" super().__init__() - self.io = PrimaiteIO.from_config(game_config.get("io_settings", {})) + self.episode_scheduler: EpisodeScheduler = build_scheduler(game_config) + """Object that returns a config corresponding to the current episode.""" + self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {})) """Handles IO for the environment. This produces sys logs, agent logs, etc.""" - - self.game_config: Dict = game_config - """PrimaiteGame definition. This can be changed between episodes to enable curriculum learning.""" - self.io = PrimaiteIO.from_config(game_config.get("io_settings", {})) - """Handles IO for the environment. This produces sys logs, agent logs, etc.""" - self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config)) + self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0)) """Current game.""" self._agent_name = next(iter(self.game.rl_agents)) """Name of the RL agent. 
Since there should only be one RL agent we can just pull the first and only key.""" @@ -94,9 +93,9 @@ class PrimaiteGymEnv(gymnasium.Env): if self.io.settings.save_agent_actions: all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()} self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter) - self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config)) - self.game.setup_for_episode(episode=self.episode_counter) self.episode_counter += 1 + self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter)) + self.game.setup_for_episode(episode=self.episode_counter) state = self.game.get_sim_state() self.game.update_agents(state=state) next_obs = self._get_obs() diff --git a/src/primaite/session/episode_schedule.py b/src/primaite/session/episode_schedule.py new file mode 100644 index 00000000..2245e2b5 --- /dev/null +++ b/src/primaite/session/episode_schedule.py @@ -0,0 +1,127 @@ +import copy +from abc import ABC, abstractmethod +from os import PathLike +from pathlib import Path +from typing import Dict, List, Mapping, Sequence, Union + +import pydantic + +from primaite import getLogger + +_LOGGER = getLogger(__name__) +import warnings +from itertools import chain + +import yaml + + +class EpisodeScheduler(pydantic.BaseModel, ABC): + """ + Episode schedulers provide functionality to select different scenarios and game setups for each episode. + + This is useful when implementing advanced RL concepts like curriculum learning and domain randomisation. + """ + + @abstractmethod + def __call__(self, episode_num: int) -> Dict: + """Return the config that should be used during this episode.""" + + +class ConstantEpisodeScheduler(EpisodeScheduler): + """ + The constant episode schedule simply provides the same game setup every time. 
+ """ + + config: Dict + + def __call__(self, episode_num: int) -> Dict: + """Return the same config every time.""" + return copy.deepcopy(self.config) + + +class EpisodeListScheduler(EpisodeScheduler): + """Episode scheduler that builds each episode's config from the yaml files listed for it in the schedule.""" + + schedule: Mapping[int, List[str]] + """Mapping from episode number to list of filenames""" + episode_data: Mapping[str, str] + """Mapping from filename to yaml string.""" + base_scenario: str + """yaml string containing the base scenario.""" + + _exceeded_episode_list: bool = False + """ + Flag that's set to true when attempting to keep generating episodes after schedule runs out. + + When this happens, we loop back to the beginning, but a warning is raised. + """ + + # Valid episode numbers are 0..len(schedule)-1 (the modulo below maps onto that range), so use >= to wrap. + def __call__(self, episode_num: int) -> Dict: + if episode_num >= len(self.schedule): + if not self._exceeded_episode_list: + self._exceeded_episode_list = True + _LOGGER.warning( + f"Running episode {episode_num} but the schedule only defines " + f"{len(self.schedule)} episodes. Looping back to the beginning" + ) + # not sure if we should be using a traditional warning, or a _LOGGER.warning + episode_num = episode_num % len(self.schedule) + + filenames_to_join = self.schedule[episode_num] + yaml_data_to_join = [self.episode_data[fn] for fn in filenames_to_join] + [self.base_scenario] + joined_yaml = "\n".join(yaml_data_to_join) + parsed_cfg = yaml.safe_load(joined_yaml) + + # Unfortunately, using placeholders like this is slightly hacky, so we have to flatten the list of agents + flat_agents_list = [] + for a in parsed_cfg["agents"]: + if isinstance(a, Sequence): + flat_agents_list.extend(a) + else: + flat_agents_list.append(a) + parsed_cfg["agents"] = flat_agents_list + + return parsed_cfg + + +def build_scheduler(config: Union[str, Path, Dict]) -> EpisodeScheduler: + """ + Convenience method to build an EpisodeScheduler with a dict, file path, or folder path.
+ + If a path to a folder is provided, it will be treated as a list of game scenarios. + Otherwise, if a dict or a single file is provided, it will be treated as a constant game scenario. + """ + # If we get a dict, return a constant episode schedule that repeats that one config forever + if isinstance(config, Dict): + return ConstantEpisodeScheduler(config=config) + + # Cast str / os.PathLike to Path so the pathlib methods used below are available + if not isinstance(config, Path): + config = Path(config) + + if not config.exists(): + raise FileNotFoundError(f"Provided config path {config} could not be found.") + + if config.is_file(): + with open(config, "r") as f: + cfg_data = yaml.safe_load(f) + return ConstantEpisodeScheduler(config=cfg_data) + + if not config.is_dir(): + raise RuntimeError("Something went wrong while building Primaite config.") + + root = config + schedule_path = root / "schedule.yaml" + + with open(schedule_path, "r") as f: + schedule = yaml.safe_load(f) + + base_scenario_path = root / schedule["base_scenario"] + files_to_load = set(chain.from_iterable(schedule["schedule"].values())) + + episode_data = {fp: (root / fp).read_text() for fp in files_to_load} + + return EpisodeListScheduler( + schedule=schedule["schedule"], episode_data=episode_data, base_scenario=base_scenario_path.read_text() + )