Merge branch 'dev' into bugfix/2453-Example_Notebooks_require_refactor

This commit is contained in:
Charlie Crane
2024-04-18 09:30:12 +01:00
37 changed files with 171 additions and 1734 deletions

View File

@@ -1 +1 @@
3.0.0b7
3.0.0b9dev

View File

@@ -114,23 +114,3 @@ def setup(overwrite_existing: bool = True) -> None:
reset_example_configs.run(overwrite_existing=True)
_LOGGER.info("PrimAITE setup complete!")
@app.command()
def session(
config: Optional[str] = None,
agent_load_file: Optional[str] = None,
) -> None:
"""
Run a PrimAITE session.
:param config: The path to the config file. Optional, if None, the example config will be used.
:type config: Optional[str]
"""
from primaite.config.load import data_manipulation_config_path
from primaite.main import run
if not config:
config = data_manipulation_config_path()
print(config)
run(config_path=config, agent_load_path=agent_load_file)

View File

@@ -1,17 +1,3 @@
training_config:
rl_framework: RLLIB_multi_agent
rl_algorithm: PPO
seed: 333
n_learn_episodes: 1
n_eval_episodes: 5
max_steps_per_episode: 128
deterministic_eval: false
n_agents: 2
agent_references:
- defender_1
- defender_2
io_settings:
save_agent_actions: true
save_step_metadata: false
@@ -1472,7 +1458,7 @@ simulation:
options:
db_server_ip: 192.168.1.14
services:
- ty DNSClient
- type: DNSClient

View File

@@ -210,8 +210,8 @@ class PrimaiteGame:
"""Create a PrimaiteGame object from a config dictionary.
The config dictionary should have the following top-level keys:
1. training_config: options for training the RL agent.
2. game_config: options for the game itself. Used by PrimaiteGame.
1. io_settings: options for logging data during training
2. game_config: options for the game itself, such as agents.
3. simulation: defines the network topology and the initial state of the simulation.
The specification for each of the three major areas is described in a separate documentation page.

View File

@@ -1,47 +0,0 @@
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
"""The main PrimAITE session runner module."""
import argparse
from pathlib import Path
from typing import Optional, Union
from primaite import getLogger
from primaite.config.load import data_manipulation_config_path, load
from primaite.session.session import PrimaiteSession
# from primaite.primaite_session import PrimaiteSession
_LOGGER = getLogger(__name__)
def run(
config_path: Optional[Union[str, Path]] = "",
agent_load_path: Optional[Union[str, Path]] = None,
) -> None:
"""
Run the PrimAITE Session.
:param training_config_path: YAML file containing configurable items defined in
`primaite.config.training_config.TrainingConfig`
:type training_config_path: Union[path, str]
:param lay_down_config_path: YAML file containing configurable items for generating network laydown.
:type lay_down_config_path: Union[path, str]
:param session_path: directory path of the session to load
:param legacy_training_config: True if the training config file is a legacy file from PrimAITE < 2.0,
otherwise False.
:param legacy_lay_down_config: True if the lay_down config file is a legacy file from PrimAITE < 2.0,
otherwise False.
"""
cfg = load(config_path)
sess = PrimaiteSession.from_config(cfg=cfg, agent_load_path=agent_load_path)
sess.start_session()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config")
args = parser.parse_args()
if not args.config:
args.config = data_manipulation_config_path()
run(args.config)

View File

@@ -48,7 +48,7 @@ class PrimaiteGymEnv(gymnasium.Env):
def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
"""Perform a step in the environment."""
# make ProxyAgent store the action chosen my the RL policy
# make ProxyAgent store the action chosen by the RL policy
step = self.game.step_counter
self.agent.store_action(action)
# apply_agent_actions accesses the action we just stored

View File

@@ -95,6 +95,7 @@ class PrimaiteIO:
@classmethod
def from_config(cls, config: Dict) -> "PrimaiteIO":
"""Create an instance of PrimaiteIO based on a configuration dict."""
config = config or {}
new = cls(settings=cls.Settings(**config))
return new

View File

@@ -1,4 +0,0 @@
from primaite.session.policy.rllib import RaySingleAgentPolicy
from primaite.session.policy.sb3 import SB3Policy
__all__ = ["SB3Policy", "RaySingleAgentPolicy"]

View File

@@ -1,82 +0,0 @@
"""Base class and common logic for RL policies."""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, Type, TYPE_CHECKING
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
class PolicyABC(ABC):
"""Base class for reinforcement learning agents."""
_registry: Dict[str, Type["PolicyABC"]] = {}
"""
Registry of policy types, keyed by name.
Automatically populated when PolicyABC subclasses are defined. Used for defining from_config.
"""
def __init_subclass__(cls, identifier: str, **kwargs: Any) -> None:
"""
Register a policy subclass.
:param name: Identifier used by from_config to create an instance of the policy.
:type name: str
:raises ValueError: When attempting to create a policy with a duplicate name.
"""
super().__init_subclass__(**kwargs)
if identifier in cls._registry:
raise ValueError(f"Duplicate policy name {identifier}")
cls._registry[identifier] = cls
return
@abstractmethod
def __init__(self, session: "PrimaiteSession") -> None:
"""
Initialize a reinforcement learning policy.
:param session: The session context.
:type session: PrimaiteSession
:param agents: The agents to train.
:type agents: List[RLAgent]
"""
self.session: "PrimaiteSession" = session
"""Reference to the session."""
@abstractmethod
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
pass
@abstractmethod
def eval(self, n_episodes: int, timesteps_per_episode: int, deterministic: bool) -> None:
"""Evaluate the agent."""
pass
@abstractmethod
def save(self, save_path: Path) -> None:
"""Save the agent."""
pass
@abstractmethod
def load(self) -> None:
"""Load agent from a file."""
pass
def close(self) -> None:
"""Close the agent."""
pass
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "PolicyABC":
"""
Create an RL policy from a config by calling the relevant subclass's from_config method.
Subclasses should not call super().from_config(), they should just handle creation form config.
"""
# Assume that basically the contents of training_config are passed into here.
# I should really define a config schema class using pydantic.
PolicyType = cls._registry[config.rl_framework]
return PolicyType.from_config(config=config, session=session)

View File

@@ -1,111 +0,0 @@
from pathlib import Path
from typing import Literal, Optional, TYPE_CHECKING
from primaite.session.environment import PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
import ray
from ray import air, tune
from ray.rllib.algorithms import ppo
from ray.rllib.algorithms.ppo import PPOConfig
from primaite import getLogger
_LOGGER = getLogger(__name__)
class RaySingleAgentPolicy(PolicyABC, identifier="RLLIB_single_agent"):
"""Single agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
super().__init__(session=session)
self.config = {
"env": PrimaiteRayEnv,
"env_config": {"game": session.game},
"disable_env_checking": True,
"num_rollout_workers": 0,
}
ray.shutdown()
ray.init()
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
self.config["training_iterations"] = n_episodes * timesteps_per_episode
self.config["train_batch_size"] = 128
self._algo = ppo.PPO(config=self.config)
_LOGGER.info("Starting RLLIB training session")
self._algo.train()
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
for ep in range(n_episodes):
obs, info = self.session.env.reset()
for step in range(self.session.game.options.max_episode_length):
action = self._algo.compute_single_action(observation=obs, explore=False)
obs, rew, term, trunc, info = self.session.env.step(action)
def save(self, save_path: Path) -> None:
"""Save the policy to a file."""
self._algo.save(save_path)
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
raise NotImplementedError
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RaySingleAgentPolicy":
"""Create a policy from a config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)
class RayMultiAgentPolicy(PolicyABC, identifier="RLLIB_multi_agent"):
"""Mutli agent RL policy using Ray RLLib."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO"], seed: Optional[int] = None):
"""Initialise multi agent policy wrapper."""
super().__init__(session=session)
self.config = (
PPOConfig()
.environment(env=PrimaiteRayMARLEnv, env_config={"game": session.game})
.rollouts(num_rollout_workers=0)
.multi_agent(
policies={agent.agent_name for agent in session.game.rl_agents},
policy_mapping_fn=lambda agent_id, episode, worker, **kw: agent_id,
)
.training(train_batch_size=128)
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
checkpoint_freq = self.session.io_manager.settings.checkpoint_interval
tune.Tuner(
"PPO",
run_config=air.RunConfig(
stop={"training_iteration": n_episodes * timesteps_per_episode},
checkpoint_config=air.CheckpointConfig(checkpoint_frequency=checkpoint_freq),
),
param_space=self.config,
).fit()
def load(self, model_path: Path) -> None:
"""Load policy parameters from a file."""
return NotImplemented
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate trained policy."""
return NotImplemented
def save(self, save_path: Path) -> None:
"""Save policy parameters to a file."""
return NotImplemented
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "RayMultiAgentPolicy":
"""Create policy from config."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,79 +0,0 @@
"""Stable baselines 3 policy."""
from pathlib import Path
from typing import Literal, Optional, Type, TYPE_CHECKING, Union
from stable_baselines3 import A2C, PPO
from stable_baselines3.a2c import MlpPolicy as A2C_MLP
from stable_baselines3.common.callbacks import CheckpointCallback
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy as PPO_MLP
from primaite.session.policy.policy import PolicyABC
if TYPE_CHECKING:
from primaite.session.session import PrimaiteSession, TrainingOptions
class SB3Policy(PolicyABC, identifier="SB3"):
"""Single agent RL policy using stable baselines 3."""
def __init__(self, session: "PrimaiteSession", algorithm: Literal["PPO", "A2C"], seed: Optional[int] = None):
"""Initialize a stable baselines 3 policy."""
super().__init__(session=session)
self._agent_class: Type[Union[PPO, A2C]]
if algorithm == "PPO":
self._agent_class = PPO
policy = PPO_MLP
elif algorithm == "A2C":
self._agent_class = A2C
policy = A2C_MLP
else:
raise ValueError(f"Unknown algorithm `{algorithm}` for stable_baselines3 policy")
self._agent = self._agent_class(
policy=policy,
env=self.session.env,
n_steps=128, # this is not the number of steps in an episode, but the number of steps in a batch
seed=seed,
)
def learn(self, n_episodes: int, timesteps_per_episode: int) -> None:
"""Train the agent."""
if self.session.save_checkpoints:
checkpoint_callback = CheckpointCallback(
save_freq=timesteps_per_episode * self.session.checkpoint_interval,
save_path=self.session.io_manager.generate_model_save_path("sb3"),
name_prefix="sb3_model",
)
else:
checkpoint_callback = None
self._agent.learn(total_timesteps=n_episodes * timesteps_per_episode, callback=checkpoint_callback)
def eval(self, n_episodes: int, deterministic: bool) -> None:
"""Evaluate the agent."""
_ = evaluate_policy(
self._agent,
self.session.env,
n_eval_episodes=n_episodes,
deterministic=deterministic,
return_episode_rewards=True,
)
def save(self, save_path: Path) -> None:
"""
Save the current policy parameters.
Warning: The recommended way to save model checkpoints is to use a callback within the `learn()` method. Please
refer to https://stable-baselines3.readthedocs.io/en/master/guide/callbacks.html for more information.
Therefore, this method is only used to save the final model.
"""
self._agent.save(save_path)
def load(self, model_path: Path) -> None:
"""Load agent from a checkpoint."""
self._agent = self._agent_class.load(model_path, env=self.session.env)
@classmethod
def from_config(cls, config: "TrainingOptions", session: "PrimaiteSession") -> "SB3Policy":
"""Create an agent from config file."""
return cls(session=session, algorithm=config.rl_algorithm, seed=config.seed)

View File

@@ -1,119 +0,0 @@
# raise DeprecationWarning("This module is deprecated")
from enum import Enum
from pathlib import Path
from typing import Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict
from primaite.session.environment import PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv
from primaite.session.io import PrimaiteIO
# from primaite.game.game import PrimaiteGame
from primaite.session.policy.policy import PolicyABC
class TrainingOptions(BaseModel):
"""Options for training the RL agent."""
model_config = ConfigDict(extra="forbid")
rl_framework: Literal["SB3", "RLLIB_single_agent", "RLLIB_multi_agent"]
rl_algorithm: Literal["PPO", "A2C"]
n_learn_episodes: int
n_eval_episodes: Optional[int] = None
max_steps_per_episode: int
# checkpoint_freq: Optional[int] = None
deterministic_eval: bool
seed: Optional[int]
n_agents: int
agent_references: List[str]
class SessionMode(Enum):
"""Helper to keep track of the current session mode."""
TRAIN = "train"
EVAL = "eval"
MANUAL = "manual"
class PrimaiteSession:
"""The main entrypoint for PrimAITE sessions, this manages a simulation, policy training, and environments."""
def __init__(self, game_cfg: Dict):
"""Initialise PrimaiteSession object."""
self.training_options: TrainingOptions
"""Options specific to agent training."""
self.mode: SessionMode = SessionMode.MANUAL
"""Current session mode."""
self.env: Union[PrimaiteGymEnv, PrimaiteRayEnv, PrimaiteRayMARLEnv]
"""The environment that the RL algorithm can consume."""
self.policy: PolicyABC
"""The reinforcement learning policy."""
self.io_manager: Optional["PrimaiteIO"] = None
"""IO manager for the session."""
self.game_cfg: Dict = game_cfg
"""Primaite Game object for managing main simulation loop and agents."""
self.save_checkpoints: bool = False
"""Whether to save checkpoints."""
self.checkpoint_interval: int = 10
"""If save_checkpoints is true, checkpoints will be saved every checkpoint_interval episodes."""
def start_session(self) -> None:
"""Commence the training/eval session."""
print("Starting Primaite Session")
self.mode = SessionMode.TRAIN
n_learn_episodes = self.training_options.n_learn_episodes
n_eval_episodes = self.training_options.n_eval_episodes
max_steps_per_episode = self.training_options.max_steps_per_episode
deterministic_eval = self.training_options.deterministic_eval
self.policy.learn(
n_episodes=n_learn_episodes,
timesteps_per_episode=max_steps_per_episode,
)
self.save_models()
self.mode = SessionMode.EVAL
if n_eval_episodes > 0:
self.policy.eval(n_episodes=n_eval_episodes, deterministic=deterministic_eval)
self.mode = SessionMode.MANUAL
def save_models(self) -> None:
"""Save the RL models."""
save_path = self.io_manager.generate_model_save_path("temp_model_name")
self.policy.save(save_path)
@classmethod
def from_config(cls, cfg: Dict, agent_load_path: Optional[str] = None) -> "PrimaiteSession":
"""Create a PrimaiteSession object from a config dictionary."""
# READ IO SETTINGS (this sets the global session path as well) # TODO: GLOBAL SIDE EFFECTS...
io_manager = PrimaiteIO.from_config(cfg.get("io_settings", {}))
sess = cls(game_cfg=cfg)
sess.io_manager = io_manager
sess.training_options = TrainingOptions(**cfg["training_config"])
sess.save_checkpoints = cfg.get("io_settings", {}).get("save_checkpoints")
sess.checkpoint_interval = cfg.get("io_settings", {}).get("checkpoint_interval")
# CREATE ENVIRONMENT
if sess.training_options.rl_framework == "RLLIB_single_agent":
sess.env = PrimaiteRayEnv(env_config=cfg)
elif sess.training_options.rl_framework == "RLLIB_multi_agent":
sess.env = PrimaiteRayMARLEnv(env_config=cfg)
elif sess.training_options.rl_framework == "SB3":
sess.env = PrimaiteGymEnv(game_config=cfg)
sess.policy = PolicyABC.from_config(sess.training_options, session=sess)
if agent_load_path:
sess.policy.load(Path(agent_load_path))
return sess

View File

@@ -271,7 +271,7 @@ class IPWirelessNetworkInterface(WirelessNetworkInterface, Layer3Interface, ABC)
# Update the state with information from Layer3Interface
state.update(Layer3Interface.describe_state(self))
state["frequency"] = self.frequency
state["frequency"] = self.frequency.value
return state

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import json
from pathlib import Path

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
from pathlib import Path
from typing import Any, Dict, Tuple, Union

View File

@@ -1,3 +1,7 @@
# flake8: noqa
raise DeprecationWarning(
"Benchmarking depends on deprecated functionality and it has not been updated to primaite v3 yet."
)
# © Crown-owned copyright 2023, Defence Science and Technology Laboratory UK
import csv
from logging import Logger