#2476 Get episode schedule working

This commit is contained in:
Marek Wolan
2024-04-23 11:51:50 +01:00
parent 2b3664ce36
commit 28c8b7c9d9
9 changed files with 198 additions and 25 deletions

View File

@@ -1,5 +1,5 @@
greens: &greens
- ref: green_client_2
- ref: green_A
team: GREEN
type: ProbabilisticAgent
agent_settings:
@@ -48,7 +48,7 @@ greens: &greens
options:
node_hostname: client_2
- ref: green_client_1
- ref: green_B
team: GREEN
type: ProbabilisticAgent
agent_settings:

View File

@@ -1,5 +1,5 @@
greens: &greens
- ref: green_client_2
- ref: green_C
team: GREEN
type: ProbabilisticAgent
agent_settings:

View File

@@ -1,5 +1,5 @@
reds: &reds
- ref: attacker_1
- ref: red_A
team: RED
type: RedDatabaseCorruptingAgent

View File

@@ -1,5 +1,5 @@
reds: &reds
- ref: attacker_2
- ref: red_B
team: RED
type: RedDatabaseCorruptingAgent

View File

@@ -23,7 +23,6 @@ game:
agents:
- *greens
- *reds
- *blue(s)
- ref: defender
team: BLUE

View File

@@ -1,17 +1,17 @@
base_scenario: scenario.yaml
schedule:
0:
green: greens_1.yaml
red: reds_1.yaml
- greens_1.yaml
- reds_1.yaml
1:
green: greens_1.yaml
red: reds_2.yaml
- greens_1.yaml
- reds_2.yaml
2:
green: greens_2.yaml
red: reds_1.yaml
- greens_2.yaml
- reds_1.yaml
3:
green: greens_2.yaml
red: reds_2.yaml
- greens_2.yaml
- reds_2.yaml
# touch base with container to see what they've implemented for training schedule and evaluation schedule - for naming convention consistency
# when you exceed the number of episodes defined in the yaml, raise a warning and loop back to the beginning

View File

@@ -110,6 +110,54 @@
"print(list(gym.game.agents.keys()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from primaite.session.environment import PrimaiteGymEnv"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env_2 = PrimaiteGymEnv(game_config='/home/cade/repos/PrimAITE/src/primaite/config/_package_data/scenario_with_placeholders')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(10):\n",
" print(env_2.episode_counter)\n",
" print(list(env_2.game.agents.keys()))\n",
" env_2.reset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"env = PrimaiteGymEnv(game_config='/home/cade/repos/PrimAITE/src/primaite/config/_package_data/data_manipulation.yaml')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sum([[1,2],[3,4]])"
]
},
{
"cell_type": "code",
"execution_count": null,

View File

@@ -1,6 +1,7 @@
import copy
import json
from typing import Any, Dict, Optional, SupportsFloat, Tuple
from os import PathLike
from typing import Any, Dict, Optional, SupportsFloat, Tuple, Union
import gymnasium
from gymnasium.core import ActType, ObsType
@@ -9,6 +10,7 @@ from ray.rllib.env.multi_agent_env import MultiAgentEnv
from primaite import getLogger
from primaite.game.agent.interface import ProxyAgent
from primaite.game.game import PrimaiteGame
from primaite.session.episode_schedule import build_scheduler, EpisodeScheduler
from primaite.session.io import PrimaiteIO
from primaite.simulator import SIM_OUTPUT
@@ -23,17 +25,14 @@ class PrimaiteGymEnv(gymnasium.Env):
assumptions about the agent list always having a list of length 1.
"""
def __init__(self, game_config: Dict):
def __init__(self, game_config: Union[Dict, str, PathLike]):
"""Initialise the environment."""
super().__init__()
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
self.episode_scheduler: EpisodeScheduler = build_scheduler(game_config)
"""Object that returns a config corresponding to the current episode."""
self.io = PrimaiteIO.from_config(self.episode_scheduler(0).get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game_config: Dict = game_config
"""PrimaiteGame definition. This can be changed between episodes to enable curriculum learning."""
self.io = PrimaiteIO.from_config(game_config.get("io_settings", {}))
"""Handles IO for the environment. This produces sys logs, agent logs, etc."""
self.game: PrimaiteGame = PrimaiteGame.from_config(copy.deepcopy(self.game_config))
self.game: PrimaiteGame = PrimaiteGame.from_config(self.episode_scheduler(0))
"""Current game."""
self._agent_name = next(iter(self.game.rl_agents))
"""Name of the RL agent. Since there should only be one RL agent we can just pull the first and only key."""
@@ -94,9 +93,9 @@ class PrimaiteGymEnv(gymnasium.Env):
if self.io.settings.save_agent_actions:
all_agent_actions = {name: agent.action_history for name, agent in self.game.agents.items()}
self.io.write_agent_actions(agent_actions=all_agent_actions, episode=self.episode_counter)
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=copy.deepcopy(self.game_config))
self.game.setup_for_episode(episode=self.episode_counter)
self.episode_counter += 1
self.game: PrimaiteGame = PrimaiteGame.from_config(cfg=self.episode_scheduler(self.episode_counter))
self.game.setup_for_episode(episode=self.episode_counter)
state = self.game.get_sim_state()
self.game.update_agents(state=state)
next_obs = self._get_obs()

View File

@@ -0,0 +1,127 @@
import copy
from abc import ABC, abstractmethod
from os import PathLike
from pathlib import Path
from typing import Dict, List, Mapping, Sequence, Union
import pydantic
from primaite import getLogger
_LOGGER = getLogger(__name__)
import warnings
from itertools import chain
import yaml
class EpisodeScheduler(pydantic.BaseModel, ABC):
    """
    Abstract base for objects that pick the game config to use for each episode.

    A scheduler is called with an episode number and returns the config dict for that episode.
    Varying the config between episodes is what enables techniques such as curriculum learning
    and domain randomisation.
    """

    @abstractmethod
    def __call__(self, episode_num: int) -> Dict:
        """
        Return the config that should be used during this episode.

        :param episode_num: Number of the episode a config is being requested for.
        :return: The game config for that episode.
        """
class ConstantEpisodeScheduler(EpisodeScheduler):
    """Scheduler that serves one fixed game setup, no matter which episode is requested."""

    # The game config that is handed out on every call.
    config: Dict

    def __call__(self, episode_num: int) -> Dict:
        """
        Return a fresh deep copy of the stored config.

        :param episode_num: Ignored; every episode receives the same config.
        :return: A deep copy, so changes made by the caller never leak back into the stored config.
        """
        fresh_config = copy.deepcopy(self.config)
        return fresh_config
class EpisodeListScheduler(EpisodeScheduler):
    """
    Episode scheduler that assembles each episode's config from a list of yaml fragments.

    For the requested episode, the yaml fragments listed in ``schedule`` are concatenated in front of
    the base scenario and the result is parsed as one yaml document. Because the fragments come
    first, anchors they define can be referenced by aliases inside the base scenario.
    """

    schedule: Mapping[int, List[str]]
    """Mapping from episode number to list of filenames"""

    episode_data: Mapping[str, str]
    """Mapping from filename to yaml string."""

    base_scenario: str
    """yaml string containing the base scenario."""

    _exceeded_episode_list: bool = False
    """
    Flag that's set to true when attempting to keep generating episodes after schedule runs out.

    When this happens, we loop back to the beginning, but a warning is raised.
    """

    def __call__(self, episode_num: int) -> Dict:
        """
        Return the parsed config for the given episode.

        Episode numbers are zero-based. Requests past the end of the schedule wrap around to the
        beginning; a warning is logged the first time that happens.

        :param episode_num: Zero-based number of the episode to build a config for.
        :return: Parsed config dict whose ``agents`` entry is a flat list.
        :raises ValueError: If the schedule contains no episodes.
        """
        num_scheduled = len(self.schedule)
        if num_scheduled == 0:
            raise ValueError("Cannot select an episode config because the episode schedule is empty.")

        # `>=`, not `>`: with zero-based numbering the last valid episode is `num_scheduled - 1`.
        # NOTE(review): wrapping assumes the schedule keys are exactly 0..num_scheduled-1 - confirm.
        if episode_num >= num_scheduled:
            if not self._exceeded_episode_list:
                self._exceeded_episode_list = True
                _LOGGER.warning(
                    f"Running episode {episode_num} but the schedule only defines "
                    f"{num_scheduled} episodes. Looping back to the beginning"
                )
            # not sure if we should be using a traditional warning, or a _LOGGER.warning
            episode_num = episode_num % num_scheduled

        filenames_to_join = self.schedule[episode_num]
        # Fragments go first so that the anchors they define exist before the base scenario uses them.
        yaml_data_to_join = [self.episode_data[fn] for fn in filenames_to_join] + [self.base_scenario]
        joined_yaml = "\n".join(yaml_data_to_join)
        parsed_cfg = yaml.safe_load(joined_yaml)

        # Unfortunately, using placeholders like this is slightly hacky, so we have to flatten the list of agents:
        # an alias that points at a list of agents shows up as a nested list inside `agents`.
        flat_agents_list = []
        for a in parsed_cfg["agents"]:
            if isinstance(a, Sequence):
                flat_agents_list.extend(a)
            else:
                flat_agents_list.append(a)
        parsed_cfg["agents"] = flat_agents_list

        return parsed_cfg
def build_scheduler(config: Union[str, PathLike, Dict]) -> "EpisodeScheduler":
    """
    Convenience method to build an EpisodeScheduler with a dict, file path, or folder path.

    If a path to a folder is provided, it will be treated as a list of game scenarios.
    Otherwise, if a dict or a single file is provided, it will be treated as a constant game scenario.

    A folder must contain a ``schedule.yaml`` with a ``base_scenario`` entry (filename of the base
    scenario, relative to the folder) and a ``schedule`` entry (mapping from episode number to the
    list of yaml filenames, relative to the folder, that make up that episode).

    :param config: Game config dict, or a path (``str`` or any ``os.PathLike``) to a yaml file or a folder.
    :return: A constant scheduler for a dict or single file, an episode-list scheduler for a folder.
    :raises FileNotFoundError: If the given path does not exist.
    :raises RuntimeError: If the path exists but is neither a regular file nor a directory.
    """
    # If we get a dict, return a constant episode schedule that repeats that one config forever
    if isinstance(config, dict):
        return ConstantEpisodeScheduler(config=config)

    # Normalise every path-like input (str, Path, or any other os.PathLike) to a Path
    config = Path(config)

    if not config.exists():
        raise FileNotFoundError(f"Provided config path {config} could not be found.")

    if config.is_file():
        with open(config, "r") as f:
            cfg_data = yaml.safe_load(f)
        return ConstantEpisodeScheduler(config=cfg_data)

    if not config.is_dir():
        raise RuntimeError("Something went wrong while building Primaite config.")

    root = config
    schedule_path = root / "schedule.yaml"

    with open(schedule_path, "r") as f:
        schedule = yaml.safe_load(f)

    base_scenario_path = root / schedule["base_scenario"]
    # Each fragment is read from disk once, even if several episodes reuse it.
    files_to_load = set(chain.from_iterable(schedule["schedule"].values()))

    episode_data = {fp: (root / fp).read_text() for fp in files_to_load}

    return EpisodeListScheduler(
        schedule=schedule["schedule"], episode_data=episode_data, base_scenario=base_scenario_path.read_text()
    )