#917 - Got RLlib fully training in PrimAITE. Started integrating the other agents into the Session class

This commit is contained in:
Chris McCarthy
2023-06-18 22:40:56 +01:00
parent 31eb36c75a
commit c09874edbe
9 changed files with 274 additions and 229 deletions

View File

@@ -1 +1 @@
2.0.0dev0
2.0.0rc1

View File

@@ -1,36 +1,84 @@
from abc import ABC, abstractmethod
from typing import Optional, Final, Dict, Any
from datetime import datetime
from pathlib import Path
from typing import Optional, Final, Dict, Any, Union, Tuple
import yaml
from primaite import getLogger
from primaite.config.training_config import TrainingConfig
from primaite.config.training_config import TrainingConfig, load
from primaite.environment.primaite_env import Primaite
_LOGGER = getLogger(__name__)
def _get_temp_session_path(session_timestamp: datetime) -> Path:
"""
Get a temp directory session path the test session will output to.
:param session_timestamp: This is the datetime that the session started.
:return: The session directory path.
"""
date_dir = session_timestamp.strftime("%Y-%m-%d")
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
session_path = Path("./") / date_dir / session_dir
session_path.mkdir(exist_ok=True, parents=True)
return session_path
class AgentABC(ABC):
@abstractmethod
def __init__(self, env: Primaite):
self._env: Primaite = env
self._training_config: Final[TrainingConfig] = self._env.training_config
self._lay_down_config: Dict[str, Any] = self._env.lay_down_config
def __init__(
self,
training_config_path,
lay_down_config_path
):
self._training_config_path = training_config_path
self._training_config: Final[TrainingConfig] = load(
self._training_config_path
)
self._lay_down_config_path = lay_down_config_path
self._env: Primaite
self._agent = None
self.session_timestamp: datetime = datetime.now()
self.session_path = _get_temp_session_path(self.session_timestamp)
self.timestamp_str = self.session_timestamp.strftime(
"%Y-%m-%d_%H-%M-%S")
@abstractmethod
def _setup(self):
pass
@abstractmethod
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
def _save_checkpoint(self):
pass
@abstractmethod
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
):
pass
@abstractmethod
def load(self):
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
):
pass
@abstractmethod
def _get_latest_checkpoint(self):
pass
@classmethod
@abstractmethod
def load(cls):
pass
@abstractmethod
@@ -44,14 +92,24 @@ class AgentABC(ABC):
class DeterministicAgentABC(AgentABC):
@abstractmethod
def __init__(self, env: Primaite):
self._env: Primaite = env
def __init__(
self,
training_config_path,
lay_down_config_path
):
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
self._env: Primaite
self._agent = None
@abstractmethod
def _setup(self):
pass
@abstractmethod
def _get_latest_checkpoint(self):
pass
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
pass
_LOGGER.warning("Deterministic agents cannot learn")
@@ -60,8 +118,9 @@ class DeterministicAgentABC(AgentABC):
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
pass
@classmethod
@abstractmethod
def load(self):
def load(cls):
pass
@abstractmethod

View File

@@ -8,170 +8,106 @@ from ray.rllib.algorithms import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
from primaite.agents.agent_abc import AgentABC
from primaite.config import training_config
from primaite.environment.primaite_env import Primaite
class DLFramework(Enum):
"""The DL Frameworks enumeration."""
TF = "tf"
TF2 = "tf2"
TORCH = "torch"
def _env_creator(env_config):
return Primaite(
training_config_path=env_config["training_config_path"],
lay_down_config_path=env_config["lay_down_config_path"],
transaction_list=env_config["transaction_list"],
session_path=env_config["session_path"],
timestamp_str=env_config["timestamp_str"]
)
def env_creator(env_config):
training_config_path = env_config["training_config_path"]
lay_down_config_path = env_config["lay_down_config_path"]
return Primaite(training_config_path, lay_down_config_path, [])
class RLlibPPO(AgentABC):
def __init__(
self,
training_config_path,
lay_down_config_path
):
super().__init__(training_config_path, lay_down_config_path)
self._ppo_config: PPOConfig
self._current_result: dict
self._setup()
def get_ppo_config(
training_config_path: Union[str, Path],
lay_down_config_path: Union[str, Path],
framework: Optional[DLFramework] = DLFramework.TORCH
) -> PPOConfig():
# Register environment
register_env("primaite", env_creator)
def _setup(self):
register_env("primaite", _env_creator)
self._ppo_config = PPOConfig()
# Setup PPO
config = PPOConfig()
config_values = training_config.load(training_config_path)
# Setup our config object to use our environment
config.environment(
env="primaite",
env_config=dict(
training_config_path=training_config_path,
lay_down_config_path=lay_down_config_path
self._ppo_config.environment(
env="primaite",
env_config=dict(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=[],
session_path=self.session_path,
timestamp_str=self.timestamp_str
)
)
)
env_config = config_values
action_type = env_config.action_type
red_agent = env_config.red_agent_identifier
self._ppo_config.training(
train_batch_size=self._training_config.num_steps
)
self._ppo_config.framework(
framework=self._training_config.deep_learning_framework.value
)
if red_agent == "RANDOM" and action_type == "NODE":
config.training(
train_batch_size=6000, lr=5e-5
) # number of steps in a training iteration
elif red_agent == "RANDOM" and action_type != "NODE":
config.training(train_batch_size=6000, lr=5e-5)
elif red_agent == "CONFIG" and action_type == "NODE":
config.training(train_batch_size=400, lr=5e-5)
elif red_agent == "CONFIG" and action_type != "NONE":
config.training(train_batch_size=500, lr=5e-5)
else:
config.training(train_batch_size=500, lr=5e-5)
self._ppo_config.rollouts(
num_rollout_workers=1,
num_envs_per_worker=1,
horizon=self._training_config.num_steps
)
self._agent: Algorithm = self._ppo_config.build()
# Decide if you want torch or tensorflow DL framework. Default is "tf"
config.framework(framework=framework.value)
def _save_checkpoint(self):
checkpoint_n = self._training_config.checkpoint_every_n_episodes
episode_count = self._current_result["episodes_total"]
if checkpoint_n > 0 and episode_count > 0:
if (
(episode_count % checkpoint_n == 0)
or (episode_count == self._training_config.num_episodes)
):
self._agent.save(self.session_path)
# Set the log level to DEBUG, INFO, WARN, or ERROR
config.debugging(seed=415, log_level="ERROR")
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
):
# Temporarily override train_batch_size and horizon
if time_steps:
self._ppo_config.train_batch_size = time_steps
self._ppo_config.horizon = time_steps
# Setup evaluation
# Explicitly set "explore"=False to override default
# config.evaluation(
# evaluation_interval=100,
# evaluation_duration=20,
# # evaluation_duration_unit="timesteps",) #default episodes
# evaluation_config={"explore": False},
# )
if not episodes:
episodes = self._training_config.num_episodes
# Setup sampling rollout workers
config.rollouts(
num_rollout_workers=4,
num_envs_per_worker=1,
horizon=128, # num parallel workers
) # max num steps in an episode
for i in range(episodes):
self._current_result = self._agent.train()
self._save_checkpoint()
self._agent.stop()
config.build() # Build config
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
):
raise NotImplementedError
return config
def _get_latest_checkpoint(self):
raise NotImplementedError
@classmethod
def load(cls):
raise NotImplementedError
def train(
num_iterations: int,
config: Optional[PPOConfig] = None,
algo: Optional[Algorithm] = None
):
"""
def save(self):
raise NotImplementedError
Requires either the algorithm config (new model) or the algorithm itself (continue training from checkpoint)
"""
start_time = time.time()
if algo is None:
algo = config.build()
elif config is None:
config = algo.get_config()
print(f"Algorithm type: {type(algo)}")
# iterations are not the same as episodes.
for i in range(num_iterations):
result = algo.train()
# # Save every 10 iterations or after last iteration in training
# if (i % 100 == 0) or (i == num_iterations - 1):
print(
f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}")
# save checkpoint file
checkpoint_file = algo.save("./")
print(f"Checkpoint saved at {checkpoint_file}")
# convert num_iterations to num_episodes
num_episodes = len(
result["hist_stats"]["episode_lengths"]) * num_iterations
# convert num_iterations to num_timesteps
num_timesteps = sum(
result["hist_stats"]["episode_lengths"] * num_iterations)
# calculate number of wins
# train time
print(f"Training took {time.time() - start_time:.2f} seconds")
print(
f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}")
return result
def load_model_from_checkpoint(config, checkpoint=None):
# create an empty Algorithm
algo = config.build()
if checkpoint is None:
# Get the checkpoint with the highest iteration number
checkpoint = get_most_recent_checkpoint(config)
# restore the agent from the checkpoint
algo.restore(checkpoint)
return algo
def get_most_recent_checkpoint(config):
"""
Get the most recent checkpoint for specified action type, red agent and algorithm
"""
env_config = list(config.env_config.values())[0]
action_type = env_config.action_type
red_agent = env_config.red_agent_identifier
algo_name = config.algo_class.__name__
# Gets the latest checkpoint (highest iteration not datetime) to use as the final trained model
relevant_checkpoints = glob.glob(
f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*"
)
checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints]
max_checkpoint = str(max(checkpoint_numbers))
checkpoint_number_to_use = "0" * (6 - len(max_checkpoint)) + max_checkpoint
checkpoint = (
relevant_checkpoints[0].split("_")[0]
+ "_"
+ checkpoint_number_to_use
+ "/rllib_checkpoint.json"
)
return checkpoint
def export(self):
raise NotImplementedError

View File

@@ -6,30 +6,74 @@ from primaite.agents.agent_abc import AgentABC
from primaite.environment.primaite_env import Primaite
from stable_baselines3.ppo import MlpPolicy as PPOMlp
class SB3PPO(AgentABC):
def __init__(self, env: Primaite):
super().__init__(env)
def __init__(
self,
training_config_path,
lay_down_config_path
):
super().__init__(training_config_path, lay_down_config_path)
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
def _setup(self):
self._env = Primaite(
training_config_path=self._training_config_path,
lay_down_config_path=self._lay_down_config_path,
transaction_list=[],
session_path=self.session_path,
timestamp_str=self.timestamp_str
)
self._agent = PPO(
PPOMlp,
self._env,
verbose=0,
n_steps=self._training_config.num_steps
n_steps=self._training_config.num_steps,
tensorboard_log=self._tensorboard_log_path
)
def learn(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None
):
if not time_steps:
time_steps = self._training_config.num_steps
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
pass
if not episodes:
episodes = self._training_config.num_episodes
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
pass
for i in range(episodes):
self._agent.learn(total_timesteps=time_steps)
def evaluate(
self,
time_steps: Optional[int] = None,
episodes: Optional[int] = None,
deterministic: bool = True
):
if not time_steps:
time_steps = self._training_config.num_steps
if not episodes:
episodes = self._training_config.num_episodes
for episode in range(episodes):
obs = self._env.reset()
for step in range(time_steps):
action, _states = self._agent.predict(
obs,
deterministic=deterministic
)
obs, rewards, done, info = self._env.step(action)
def load(self):
pass
raise NotImplementedError
def save(self):
pass
raise NotImplementedError
def export(self):
pass
raise NotImplementedError

View File

@@ -95,15 +95,32 @@ class VerboseLevel(Enum):
class AgentFramework(Enum):
NONE = 0
"Custom Agent"
SB3 = 1
"Stable Baselines3"
RLLIB = 2
"Ray RLlib"
class DeepLearningFramework(Enum):
"""The deep learning framework enumeration."""
TF = "tf"
"Tensorflow"
TF2 = "tf2"
"Tensorflow 2.x"
TORCH = "torch"
"PyTorch"
class RedAgentIdentifier(Enum):
A2C = 1
"Advantage Actor Critic"
PPO = 2
"Proximal Policy Optimization"
HARDCODED = 3
"Custom Agent"
RANDOM = 4
"Custom Agent"
class ActionType(Enum):

View File

@@ -1,26 +1,52 @@
# Main Config File
# Generic config values
# Choose one of these (dependent on Agent being trained)
# "STABLE_BASELINES3_PPO"
# "STABLE_BASELINES3_A2C"
# "GENERIC"
agent_identifier: STABLE_BASELINES3_A2C
# Sets which agent algorithm framework will be used:
# Options are:
# "SB3" (Stable Baselines3)
# "RLLIB" (Ray RLlib)
# "NONE" (Custom Agent)
agent_framework: RLLIB
# Sets which deep learning framework will be used. Default is TF (Tensorflow).
# Options are:
# "TF" (Tensorflow)
# "TF2" (Tensorflow 2.X)
# "TORCH" (PyTorch)
deep_learning_framework: TORCH
# Sets which Red Agent algo/class will be used:
# Options are:
# "A2C" (Advantage Actor Critic)
# "PPO" (Proximal Policy Optimization)
# "HARDCODED" (Custom Agent)
# "RANDOM" (Random Action)
red_agent_identifier: PPO
# Sets How the Action Space is defined:
# "NODE"
# "ACL"
# "ANY" node and acl actions
action_type: NODE
# Number of episodes to run per session
num_episodes: 10
# Number of time_steps per episode
num_steps: 256
# Sets how often the agent will save a checkpoint (every n time episodes).
# Set to 0 if no checkpoints are required. Default is 10
checkpoint_every_n_episodes: 5
# Time delay between steps (for generic agents)
time_delay: 10
# Type of session to be run (TRAINING or EVALUATION)
session_type: TRAINING
# Determine whether to load an agent from file
load_agent: False
# File path and file name of agent if you're loading one in
agent_load_file: C:\[Path]\[agent_saved_filename.zip]

View File

@@ -8,6 +8,7 @@ from typing import Any, Dict, Final, Union, Optional
import yaml
from primaite import USERS_CONFIG_DIR, getLogger
from primaite.common.enums import DeepLearningFramework
from primaite.common.enums import ActionType, RedAgentIdentifier, \
AgentFramework, SessionType
@@ -20,10 +21,13 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
class TrainingConfig:
"""The Training Config class."""
agent_framework: AgentFramework = AgentFramework.SB3
"The agent framework."
"The AgentFramework"
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
"The DeepLearningFramework."
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
"The red agent/algo class."
"The RedAgentIdentifier."
action_type: ActionType = ActionType.ANY
"The ActionType to use."
@@ -33,6 +37,10 @@ class TrainingConfig:
num_steps: int = 256
"The number of steps in an episode."
checkpoint_every_n_episodes: int = 5
"The agent will save a checkpoint every n episodes."
observation_space: dict = field(
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
)
@@ -148,6 +156,7 @@ class TrainingConfig:
) -> TrainingConfig:
field_enum_map = {
"agent_framework": AgentFramework,
"deep_learning_framework": DeepLearningFramework,
"red_agent_identifier": RedAgentIdentifier,
"action_type": ActionType,
"session_type": SessionType
@@ -155,7 +164,7 @@ class TrainingConfig:
for field, enum_class in field_enum_map.items():
if field in config_dict:
config_dict[field] = enum_class[field]
config_dict[field] = enum_class[config_dict[field]]
return TrainingConfig(**config_dict)
@@ -219,7 +228,7 @@ def load(file_path: Union[str, Path],
)
_LOGGER.error(msg)
try:
return TrainingConfig.from_dict(**config)
return TrainingConfig.from_dict(config)
except TypeError as e:
msg = (
f"Error when creating an instance of {TrainingConfig} "

View File

@@ -5,7 +5,7 @@ import csv
import logging
from datetime import datetime
from pathlib import Path
from typing import Dict, Tuple, Union
from typing import Dict, Tuple, Union, Final
import networkx as nx
import numpy as np
@@ -77,6 +77,8 @@ class Primaite(Env):
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
self.session_path: Final[Path] = session_path
self.timestamp_str: Final[str] = timestamp_str
self._training_config_path = training_config_path
self._lay_down_config_path = lay_down_config_path
@@ -93,7 +95,7 @@ class Primaite(Env):
self.transaction_list = transaction_list
# The agent in use
self.agent_identifier = self.training_config.agent_identifier
self.agent_identifier = self.training_config.red_agent_identifier
# Create a dictionary to hold all the nodes
self.nodes: Dict[str, NodeUnion] = {}

View File

@@ -108,54 +108,6 @@ def run_stable_baselines3_ppo(
env.close()
def run_stable_baselines3_a2c(
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
):
"""
Run against a stable_baselines3 A2C agent.
:param env: An instance of
:class:`~primaite.environment.primaite_env.Primaite`.
:param config_values: An instance of
:class:`~primaite.config.training_config.TrainingConfig`.
:param session_path: The directory path the session is writing to.
:param timestamp_str: The session timestamp in the format:
<yyyy-mm-dd>_<hh-mm-ss>.
"""
if config_values.load_agent:
try:
agent = A2C.load(
config_values.agent_load_file,
env,
verbose=0,
n_steps=config_values.num_steps,
)
except Exception:
print(
"ERROR: Could not load agent at location: "
+ config_values.agent_load_file
)
_LOGGER.error("Could not load agent")
_LOGGER.error("Exception occured", exc_info=True)
else:
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
if config_values.session_type == "TRAINING":
# We're in a training session
print("Starting training session...")
_LOGGER.debug("Starting training session...")
for episode in range(config_values.num_episodes):
agent.learn(total_timesteps=config_values.num_steps)
_save_agent(agent, session_path, timestamp_str)
else:
# Default to being in an evaluation session
print("Starting evaluation session...")
_LOGGER.debug("Starting evaluation session...")
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
env.close()
def _write_session_metadata_file(
session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite
):