#2476 Get episode schedule working
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
greens: &greens
|
||||
- ref: green_client_2
|
||||
- ref: green_A
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
@@ -48,7 +48,7 @@ greens: &greens
|
||||
options:
|
||||
node_hostname: client_2
|
||||
|
||||
- ref: green_client_1
|
||||
- ref: green_B
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
greens: &greens
|
||||
- ref: green_client_2
|
||||
- ref: green_C
|
||||
team: GREEN
|
||||
type: ProbabilisticAgent
|
||||
agent_settings:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
reds: &reds
|
||||
- ref: attacker_1
|
||||
- ref: red_A
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
reds: &reds
|
||||
- ref: attacker_2
|
||||
- ref: red_B
|
||||
team: RED
|
||||
type: RedDatabaseCorruptingAgent
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ game:
|
||||
agents:
|
||||
- *greens
|
||||
- *reds
|
||||
- *blue(s)
|
||||
|
||||
- ref: defender
|
||||
team: BLUE
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
base_scenario: scenario.yaml
|
||||
schedule:
|
||||
0:
|
||||
green: greens_1.yaml
|
||||
red: reds_1.yaml
|
||||
- greens_1.yaml
|
||||
- reds_1.yaml
|
||||
1:
|
||||
green: greens_1.yaml
|
||||
red: reds_2.yaml
|
||||
- greens_1.yaml
|
||||
- reds_2.yaml
|
||||
2:
|
||||
green: greens_2.yaml
|
||||
red: reds_1.yaml
|
||||
- greens_2.yaml
|
||||
- reds_1.yaml
|
||||
3:
|
||||
green: greens_2.yaml
|
||||
red: reds_2.yaml
|
||||
- greens_2.yaml
|
||||
- reds_2.yaml
|
||||
|
||||
# touch base with container to see what they've implemented for training schedule and evaluation schedule - for naming convention consistency
|
||||
# when you exceed the number of episodes defined in the yaml, raise a warning and loop back to the beginning
|
||||
|
||||
@@ -110,6 +110,54 @@
|
||||
"print(list(gym.game.agents.keys()))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from primaite.session.environment import PrimaiteGymEnv"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env_2 = PrimaiteGymEnv(game_config='/home/cade/repos/PrimAITE/src/primaite/config/_package_data/scenario_with_placeholders')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"for i in range(10):\n",
|
||||
" print(env_2.episode_counter)\n",
|
||||
" print(list(env_2.game.agents.keys()))\n",
|
||||
" env_2.reset()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"env = PrimaiteGymEnv(game_config='/home/cade/repos/PrimAITE/src/primaite/config/_package_data/data_manipulation.yaml')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"sum([[1,2],[3,4]])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import copy
|
||||
import json
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple
|
||||
from os import PathLike
|
||||
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
|
||||
|
||||
import gymnasium
|
||||
from gymnasium.core import ActType, ObsType
|
||||
@@ -9,6 +10,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
|
||||
from primaite import getLogger
|
||||
from primaite.game.agent.interface import ProxyAgent
|
||||
from primaite.game.game import PrimaiteGame
|
||||
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
|
||||
from primaite.session.io import PrimaiteIO
|
||||
from primaite.simulator import SIM_OUTPUT
|
||||
|
||||
@@ -23,17 +25,14 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
assumptions about the agent list always having a list of length 1.
|
||||
"""
|
||||
|
||||
def __init__(self, game_config: Dict):
|
||||
def __init__(self, game_config: Union[Dict, str, PathLike]):
|
||||
"""Initialise the environment."""
|
||||
super().__init__()
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
self.episode_scheduler: EpisodeScheduler = build_scheduler(game_config)
|
||||
"""Object that returns a config corresponding to the current episode."""
|
||||
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
|
||||
self.game_config: Dict = game_config
|
||||
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
|
||||
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
|
||||
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
|
||||
"""Current game."""
|
||||
self._agent_name = next(iter(self.game.rl_agents))
|
||||
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
|
||||
@@ -94,9 +93,9 @@ class PrimaiteGymEnv(gymnasium.Env):
|
||||
if self.io.settings.save_agent_actions:
|
||||
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
|
||||
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
self.episode_counter += 1
|
||||
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
|
||||
self.game.setup_for_episode(episode=self.episode_counter)
|
||||
state = self.game.get_sim_state()
|
||||
self.game.update_agents(state=state)
|
||||
next_obs = self._get_obs()
|
||||
|
||||
127
src/primaite/session/episode_schedule.py
Normal file
127
src/primaite/session/episode_schedule.py
Normal file
@@ -0,0 +1,127 @@
|
||||
import copy
|
||||
from abc import ABC, abstractmethod
|
||||
from os import PathLike
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Mapping, Sequence, Union
|
||||
|
||||
import pydantic
|
||||
|
||||
from primaite import getLogger
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
import warnings
|
||||
from itertools import chain
|
||||
|
||||
import yaml
|
||||
|
||||
|
||||
class EpisodeScheduler(pydantic.BaseModel, ABC):
    """
    Abstract base for objects that pick the game config to use for a given episode.

    A scheduler is called with an episode number and hands back a config dict. Letting the config vary
    between episodes is what enables techniques such as curriculum learning and domain randomisation.
    """

    @abstractmethod
    def __call__(self, episode_num: int) -> Dict:
        """
        Return the config that should be used during this episode.

        :param episode_num: Number of the episode a config is being requested for.
        :return: Game config dict for that episode.
        """
        ...
|
||||
|
||||
|
||||
class ConstantEpisodeScheduler(EpisodeScheduler):
    """Scheduler that serves one fixed game setup for every episode."""

    # The game config that every episode is built from.
    config: Dict

    def __call__(self, episode_num: int) -> Dict:
        """
        Return a deep copy of the stored config; ``episode_num`` is ignored.

        A copy is returned so that the caller can mutate the result without affecting later episodes.
        """
        fresh_config = copy.deepcopy(self.config)
        return fresh_config
|
||||
|
||||
|
||||
class EpisodeListScheduler(EpisodeScheduler):
    """
    Scheduler that assembles each episode's config from a per-episode list of yaml snippets.

    For a given episode, the yaml text of every file listed in ``schedule`` for that episode is joined with
    the base scenario text and parsed as a single yaml document. When more episodes are requested than the
    schedule defines, the schedule wraps around to the beginning and a warning is logged once.
    """

    schedule: Mapping[int, List[str]]
    """Mapping from episode number to list of filenames"""
    episode_data: Mapping[str, str]
    """Mapping from filename to yaml string."""
    base_scenario: str
    """yaml string containing the base scenario."""

    _exceeded_episode_list: bool = False
    """
    Flag that's set to true when attempting to keep generating episodes after schedule runs out.

    When this happens, we loop back to the beginning, but a warning is raised.
    """

    def __call__(self, episode_num: int) -> Dict:
        """
        Build and return the config for the requested episode.

        :param episode_num: Zero-based episode number. Values at or beyond the number of scheduled episodes
            wrap around to the start of the schedule.
        :return: Parsed config dict whose ``agents`` entry is a flat list.
        """
        num_episodes = len(self.schedule)

        # NOTE: wrapping assumes the schedule keys are the contiguous integers 0..num_episodes-1.
        # `>=` (not `>`) is required: episode_num == num_episodes is already one past the last valid key.
        if episode_num >= num_episodes:
            if not self._exceeded_episode_list:
                self._exceeded_episode_list = True
                # not sure if we should be using a traditional warning, or a _LOGGER.warning
                _LOGGER.warning(
                    f"Running episode {episode_num} but the schedule only defines "
                    f"{num_episodes} episodes. Looping back to the beginning"
                )
            episode_num = episode_num % num_episodes

        filenames_to_join = self.schedule[episode_num]
        # The base scenario goes last so that the snippets joined before it are parsed first; it is then
        # able to reference yaml anchors that those snippets define.
        yaml_data_to_join = [self.episode_data[fn] for fn in filenames_to_join] + [self.base_scenario]
        joined_yaml = "\n".join(yaml_data_to_join)
        parsed_cfg = yaml.safe_load(joined_yaml)

        # Unfortunately, using placeholders like this is slightly hacky, so we have to flatten the list of agents
        flat_agents_list = []
        for a in parsed_cfg["agents"]:
            if isinstance(a, Sequence):
                # An aliased placeholder expands to a nested list of agents; splice its items in.
                flat_agents_list.extend(a)
            else:
                flat_agents_list.append(a)
        parsed_cfg["agents"] = flat_agents_list

        return parsed_cfg
|
||||
|
||||
|
||||
def build_scheduler(config: Union[str, Path, PathLike, Dict]) -> EpisodeScheduler:
    """
    Convenience method to build an EpisodeScheduler with a dict, file path, or folder path.

    If a path to a folder is provided, it will be treated as a list of game scenarios.
    Otherwise, if a dict or a single file is provided, it will be treated as a constant game scenario.

    :param config: Either an already-parsed config dict, or a path (``str``, ``Path`` or any ``os.PathLike``)
        to a yaml config file or to a folder containing a ``schedule.yaml``.
    :return: A ``ConstantEpisodeScheduler`` for a dict or single file, or an ``EpisodeListScheduler`` for a folder.
    :raises FileNotFoundError: If the given path, or the folder's ``schedule.yaml``, does not exist.
    :raises RuntimeError: If the path exists but is neither a regular file nor a directory.
    """
    # If we get a dict, return a constant episode schedule that repeats that one config forever
    if isinstance(config, Dict):
        return ConstantEpisodeScheduler(config=config)

    # Normalise strings and generic os.PathLike objects to Path so the pathlib API below is available
    if isinstance(config, (str, PathLike)):
        config = Path(config)

    if not config.exists():
        raise FileNotFoundError(f"Provided config path {config} could not be found.")

    if config.is_file():
        with open(config, "r") as f:
            cfg_data = yaml.safe_load(f)
        return ConstantEpisodeScheduler(config=cfg_data)

    if not config.is_dir():
        raise RuntimeError("Something went wrong while building Primaite config.")

    root = config
    schedule_path = root / "schedule.yaml"
    if not schedule_path.is_file():
        raise FileNotFoundError(f"Config folder {root} does not contain a schedule.yaml file.")

    with open(schedule_path, "r") as f:
        schedule = yaml.safe_load(f)

    base_scenario_path = root / schedule["base_scenario"]
    # Read each referenced file only once, even if several episodes use it
    files_to_load = set(chain.from_iterable(schedule["schedule"].values()))

    episode_data = {fp: (root / fp).read_text() for fp in files_to_load}

    return EpisodeListScheduler(
        schedule=schedule["schedule"], episode_data=episode_data, base_scenario=base_scenario_path.read_text()
    )
|
||||
Reference in New Issue
Block a user