#917 - Got RLlib fully training in PrimAITE. Started integrating the the other agents into the Session class
This commit is contained in:
@@ -1 +1 @@
|
||||
2.0.0dev0
|
||||
2.0.0rc1
|
||||
|
||||
@@ -1,36 +1,84 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Final, Dict, Any
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, Final, Dict, Any, Union, Tuple
|
||||
|
||||
import yaml
|
||||
|
||||
from primaite import getLogger
|
||||
from primaite.config.training_config import TrainingConfig
|
||||
from primaite.config.training_config import TrainingConfig, load
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
_LOGGER = getLogger(__name__)
|
||||
|
||||
|
||||
def _get_temp_session_path(session_timestamp: datetime) -> Path:
|
||||
"""
|
||||
Get a temp directory session path the test session will output to.
|
||||
|
||||
:param session_timestamp: This is the datetime that the session started.
|
||||
:return: The session directory path.
|
||||
"""
|
||||
date_dir = session_timestamp.strftime("%Y-%m-%d")
|
||||
session_dir = session_timestamp.strftime("%Y-%m-%d_%H-%M-%S")
|
||||
session_path = Path("./") / date_dir / session_dir
|
||||
session_path.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
return session_path
|
||||
|
||||
|
||||
class AgentABC(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, env: Primaite):
|
||||
self._env: Primaite = env
|
||||
self._training_config: Final[TrainingConfig] = self._env.training_config
|
||||
self._lay_down_config: Dict[str, Any] = self._env.lay_down_config
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
self._training_config_path = training_config_path
|
||||
self._training_config: Final[TrainingConfig] = load(
|
||||
self._training_config_path
|
||||
)
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
self.session_timestamp: datetime = datetime.now()
|
||||
self.session_path = _get_temp_session_path(self.session_timestamp)
|
||||
|
||||
self.timestamp_str = self.session_timestamp.strftime(
|
||||
"%Y-%m-%d_%H-%M-%S")
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
def _save_checkpoint(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -44,14 +92,24 @@ class AgentABC(ABC):
|
||||
|
||||
class DeterministicAgentABC(AgentABC):
|
||||
@abstractmethod
|
||||
def __init__(self, env: Primaite):
|
||||
self._env: Primaite = env
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
self._env: Primaite
|
||||
self._agent = None
|
||||
|
||||
@abstractmethod
|
||||
def _setup(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def _get_latest_checkpoint(self):
|
||||
pass
|
||||
|
||||
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
pass
|
||||
_LOGGER.warning("Deterministic agents cannot learn")
|
||||
@@ -60,8 +118,9 @@ class DeterministicAgentABC(AgentABC):
|
||||
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(self):
|
||||
def load(cls):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -8,170 +8,106 @@ from ray.rllib.algorithms import Algorithm
|
||||
from ray.rllib.algorithms.ppo import PPOConfig
|
||||
from ray.tune.registry import register_env
|
||||
|
||||
from primaite.agents.agent_abc import AgentABC
|
||||
from primaite.config import training_config
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
|
||||
|
||||
class DLFramework(Enum):
|
||||
"""The DL Frameworks enumeration."""
|
||||
TF = "tf"
|
||||
TF2 = "tf2"
|
||||
TORCH = "torch"
|
||||
def _env_creator(env_config):
|
||||
return Primaite(
|
||||
training_config_path=env_config["training_config_path"],
|
||||
lay_down_config_path=env_config["lay_down_config_path"],
|
||||
transaction_list=env_config["transaction_list"],
|
||||
session_path=env_config["session_path"],
|
||||
timestamp_str=env_config["timestamp_str"]
|
||||
)
|
||||
|
||||
|
||||
def env_creator(env_config):
|
||||
training_config_path = env_config["training_config_path"]
|
||||
lay_down_config_path = env_config["lay_down_config_path"]
|
||||
return Primaite(training_config_path, lay_down_config_path, [])
|
||||
class RLlibPPO(AgentABC):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._ppo_config: PPOConfig
|
||||
self._current_result: dict
|
||||
self._setup()
|
||||
|
||||
def get_ppo_config(
|
||||
training_config_path: Union[str, Path],
|
||||
lay_down_config_path: Union[str, Path],
|
||||
framework: Optional[DLFramework] = DLFramework.TORCH
|
||||
) -> PPOConfig():
|
||||
# Register environment
|
||||
register_env("primaite", env_creator)
|
||||
def _setup(self):
|
||||
register_env("primaite", _env_creator)
|
||||
self._ppo_config = PPOConfig()
|
||||
|
||||
# Setup PPO
|
||||
config = PPOConfig()
|
||||
|
||||
config_values = training_config.load(training_config_path)
|
||||
|
||||
# Setup our config object to use our environment
|
||||
config.environment(
|
||||
env="primaite",
|
||||
env_config=dict(
|
||||
training_config_path=training_config_path,
|
||||
lay_down_config_path=lay_down_config_path
|
||||
self._ppo_config.environment(
|
||||
env="primaite",
|
||||
env_config=dict(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
transaction_list=[],
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
env_config = config_values
|
||||
action_type = env_config.action_type
|
||||
red_agent = env_config.red_agent_identifier
|
||||
self._ppo_config.training(
|
||||
train_batch_size=self._training_config.num_steps
|
||||
)
|
||||
self._ppo_config.framework(
|
||||
framework=self._training_config.deep_learning_framework.value
|
||||
)
|
||||
|
||||
if red_agent == "RANDOM" and action_type == "NODE":
|
||||
config.training(
|
||||
train_batch_size=6000, lr=5e-5
|
||||
) # number of steps in a training iteration
|
||||
elif red_agent == "RANDOM" and action_type != "NODE":
|
||||
config.training(train_batch_size=6000, lr=5e-5)
|
||||
elif red_agent == "CONFIG" and action_type == "NODE":
|
||||
config.training(train_batch_size=400, lr=5e-5)
|
||||
elif red_agent == "CONFIG" and action_type != "NONE":
|
||||
config.training(train_batch_size=500, lr=5e-5)
|
||||
else:
|
||||
config.training(train_batch_size=500, lr=5e-5)
|
||||
self._ppo_config.rollouts(
|
||||
num_rollout_workers=1,
|
||||
num_envs_per_worker=1,
|
||||
horizon=self._training_config.num_steps
|
||||
)
|
||||
self._agent: Algorithm = self._ppo_config.build()
|
||||
|
||||
# Decide if you want torch or tensorflow DL framework. Default is "tf"
|
||||
config.framework(framework=framework.value)
|
||||
def _save_checkpoint(self):
|
||||
checkpoint_n = self._training_config.checkpoint_every_n_episodes
|
||||
episode_count = self._current_result["episodes_total"]
|
||||
if checkpoint_n > 0 and episode_count > 0:
|
||||
if (
|
||||
(episode_count % checkpoint_n == 0)
|
||||
or (episode_count == self._training_config.num_episodes)
|
||||
):
|
||||
self._agent.save(self.session_path)
|
||||
|
||||
# Set the log level to DEBUG, INFO, WARN, or ERROR
|
||||
config.debugging(seed=415, log_level="ERROR")
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
):
|
||||
# Temporarily override train_batch_size and horizon
|
||||
if time_steps:
|
||||
self._ppo_config.train_batch_size = time_steps
|
||||
self._ppo_config.horizon = time_steps
|
||||
|
||||
# Setup evaluation
|
||||
# Explicitly set "explore"=False to override default
|
||||
# config.evaluation(
|
||||
# evaluation_interval=100,
|
||||
# evaluation_duration=20,
|
||||
# # evaluation_duration_unit="timesteps",) #default episodes
|
||||
# evaluation_config={"explore": False},
|
||||
# )
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
# Setup sampling rollout workers
|
||||
config.rollouts(
|
||||
num_rollout_workers=4,
|
||||
num_envs_per_worker=1,
|
||||
horizon=128, # num parralel workiers
|
||||
) # max num steps in an episode
|
||||
for i in range(episodes):
|
||||
self._current_result = self._agent.train()
|
||||
self._save_checkpoint()
|
||||
self._agent.stop()
|
||||
|
||||
config.build() # Build config
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
return config
|
||||
def _get_latest_checkpoint(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def load(cls):
|
||||
raise NotImplementedError
|
||||
|
||||
def train(
|
||||
num_iterations: int,
|
||||
config: Optional[PPOConfig] = None,
|
||||
algo: Optional[Algorithm] = None
|
||||
):
|
||||
"""
|
||||
def save(self):
|
||||
raise NotImplementedError
|
||||
|
||||
Requires either the algorithm config (new model) or the algorithm itself (continue training from checkpoint)
|
||||
"""
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
if algo is None:
|
||||
algo = config.build()
|
||||
elif config is None:
|
||||
config = algo.get_config()
|
||||
|
||||
print(f"Algorithm type: {type(algo)}")
|
||||
|
||||
# iterations are not the same as episodes.
|
||||
for i in range(num_iterations):
|
||||
result = algo.train()
|
||||
# # Save every 10 iterations or after last iteration in training
|
||||
# if (i % 100 == 0) or (i == num_iterations - 1):
|
||||
print(
|
||||
f"Iteration={i}, Mean Reward={result['episode_reward_mean']:.2f}")
|
||||
# save checkpoint file
|
||||
checkpoint_file = algo.save("./")
|
||||
print(f"Checkpoint saved at {checkpoint_file}")
|
||||
|
||||
# convert num_iterations to num_episodes
|
||||
num_episodes = len(
|
||||
result["hist_stats"]["episode_lengths"]) * num_iterations
|
||||
# convert num_iterations to num_timesteps
|
||||
num_timesteps = sum(
|
||||
result["hist_stats"]["episode_lengths"] * num_iterations)
|
||||
# calculate number of wins
|
||||
|
||||
# train time
|
||||
print(f"Training took {time.time() - start_time:.2f} seconds")
|
||||
print(
|
||||
f"Number of episodes {num_episodes}, Number of timesteps: {num_timesteps}")
|
||||
return result
|
||||
|
||||
|
||||
def load_model_from_checkpoint(config, checkpoint=None):
|
||||
# create an empty Algorithm
|
||||
algo = config.build()
|
||||
|
||||
if checkpoint is None:
|
||||
# Get the checkpoint with the highest iteration number
|
||||
checkpoint = get_most_recent_checkpoint(config)
|
||||
|
||||
# restore the agent from the checkpoint
|
||||
algo.restore(checkpoint)
|
||||
|
||||
return algo
|
||||
|
||||
|
||||
def get_most_recent_checkpoint(config):
|
||||
"""
|
||||
Get the most recent checkpoint for specified action type, red agent and algorithm
|
||||
"""
|
||||
|
||||
env_config = list(config.env_config.values())[0]
|
||||
action_type = env_config.action_type
|
||||
red_agent = env_config.red_agent_identifier
|
||||
algo_name = config.algo_class.__name__
|
||||
|
||||
# Gets the latest checkpoint (highest iteration not datetime) to use as the final trained model
|
||||
relevant_checkpoints = glob.glob(
|
||||
f"/app/outputs/agents/{action_type}/{red_agent}/{algo_name}/*"
|
||||
)
|
||||
checkpoint_numbers = [int(i.split("_")[1]) for i in relevant_checkpoints]
|
||||
max_checkpoint = str(max(checkpoint_numbers))
|
||||
checkpoint_number_to_use = "0" * (6 - len(max_checkpoint)) + max_checkpoint
|
||||
checkpoint = (
|
||||
relevant_checkpoints[0].split("_")[0]
|
||||
+ "_"
|
||||
+ checkpoint_number_to_use
|
||||
+ "/rllib_checkpoint.json"
|
||||
)
|
||||
|
||||
return checkpoint
|
||||
def export(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -6,30 +6,74 @@ from primaite.agents.agent_abc import AgentABC
|
||||
from primaite.environment.primaite_env import Primaite
|
||||
from stable_baselines3.ppo import MlpPolicy as PPOMlp
|
||||
|
||||
|
||||
class SB3PPO(AgentABC):
|
||||
def __init__(self, env: Primaite):
|
||||
super().__init__(env)
|
||||
def __init__(
|
||||
self,
|
||||
training_config_path,
|
||||
lay_down_config_path
|
||||
):
|
||||
super().__init__(training_config_path, lay_down_config_path)
|
||||
self._tensorboard_log_path = self.session_path / "tensorboard_logs"
|
||||
self._tensorboard_log_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def _setup(self):
|
||||
self._env = Primaite(
|
||||
training_config_path=self._training_config_path,
|
||||
lay_down_config_path=self._lay_down_config_path,
|
||||
transaction_list=[],
|
||||
session_path=self.session_path,
|
||||
timestamp_str=self.timestamp_str
|
||||
)
|
||||
self._agent = PPO(
|
||||
PPOMlp,
|
||||
self._env,
|
||||
verbose=0,
|
||||
n_steps=self._training_config.num_steps
|
||||
n_steps=self._training_config.num_steps,
|
||||
tensorboard_log=self._tensorboard_log_path
|
||||
)
|
||||
|
||||
def learn(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None
|
||||
):
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
def learn(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
pass
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
def evaluate(self, time_steps: Optional[int], episodes: Optional[int]):
|
||||
pass
|
||||
for i in range(episodes):
|
||||
self._agent.learn(total_timesteps=time_steps)
|
||||
|
||||
def evaluate(
|
||||
self,
|
||||
time_steps: Optional[int] = None,
|
||||
episodes: Optional[int] = None,
|
||||
deterministic: bool = True
|
||||
):
|
||||
if not time_steps:
|
||||
time_steps = self._training_config.num_steps
|
||||
|
||||
if not episodes:
|
||||
episodes = self._training_config.num_episodes
|
||||
|
||||
for episode in range(episodes):
|
||||
obs = self._env.reset()
|
||||
|
||||
for step in range(time_steps):
|
||||
action, _states = self._agent.predict(
|
||||
obs,
|
||||
deterministic=deterministic
|
||||
)
|
||||
obs, rewards, done, info = self._env.step(action)
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
def export(self):
|
||||
pass
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -95,15 +95,32 @@ class VerboseLevel(Enum):
|
||||
|
||||
class AgentFramework(Enum):
|
||||
NONE = 0
|
||||
"Custom Agent"
|
||||
SB3 = 1
|
||||
"Stable Baselines3"
|
||||
RLLIB = 2
|
||||
"Ray RLlib"
|
||||
|
||||
|
||||
class DeepLearningFramework(Enum):
|
||||
"""The deep learning framework enumeration."""
|
||||
TF = "tf"
|
||||
"Tensorflow"
|
||||
TF2 = "tf2"
|
||||
"Tensorflow 2.x"
|
||||
TORCH = "torch"
|
||||
"PyTorch"
|
||||
|
||||
|
||||
class RedAgentIdentifier(Enum):
|
||||
A2C = 1
|
||||
"Advantage Actor Critic"
|
||||
PPO = 2
|
||||
"Proximal Policy Optimization"
|
||||
HARDCODED = 3
|
||||
"Custom Agent"
|
||||
RANDOM = 4
|
||||
"Custom Agent"
|
||||
|
||||
|
||||
class ActionType(Enum):
|
||||
|
||||
@@ -1,26 +1,52 @@
|
||||
# Main Config File
|
||||
|
||||
# Generic config values
|
||||
# Choose one of these (dependent on Agent being trained)
|
||||
# "STABLE_BASELINES3_PPO"
|
||||
# "STABLE_BASELINES3_A2C"
|
||||
# "GENERIC"
|
||||
agent_identifier: STABLE_BASELINES3_A2C
|
||||
# Sets which agent algorithm framework will be used:
|
||||
# Options are:
|
||||
# "SB3" (Stable Baselines3)
|
||||
# "RLLIB" (Ray RLlib)
|
||||
# "NONE" (Custom Agent)
|
||||
agent_framework: RLLIB
|
||||
|
||||
# Sets which deep learning framework will be used. Default is TF (Tensorflow).
|
||||
# Options are:
|
||||
# "TF" (Tensorflow)
|
||||
# TF2 (Tensorflow 2.X)
|
||||
# TORCH (PyTorch)
|
||||
deep_learning_framework: TORCH
|
||||
|
||||
# Sets which Red Agent algo/class will be used:
|
||||
# Options are:
|
||||
# "A2C" (Advantage Actor Critic)
|
||||
# "PPO" (Proximal Policy Optimization)
|
||||
# "HARDCODED" (Custom Agent)
|
||||
# "RANDOM" (Random Action)
|
||||
red_agent_identifier: PPO
|
||||
|
||||
# Sets How the Action Space is defined:
|
||||
# "NODE"
|
||||
# "ACL"
|
||||
# "ANY" node and acl actions
|
||||
action_type: NODE
|
||||
|
||||
# Number of episodes to run per session
|
||||
num_episodes: 10
|
||||
|
||||
# Number of time_steps per episode
|
||||
num_steps: 256
|
||||
|
||||
# Sets how often the agent will save a checkpoint (every n time episodes).
|
||||
# Set to 0 if no checkpoints are required. Default is 10
|
||||
checkpoint_every_n_episodes: 5
|
||||
|
||||
# Time delay between steps (for generic agents)
|
||||
time_delay: 10
|
||||
|
||||
# Type of session to be run (TRAINING or EVALUATION)
|
||||
session_type: TRAINING
|
||||
|
||||
# Determine whether to load an agent from file
|
||||
load_agent: False
|
||||
|
||||
# File path and file name of agent if you're loading one in
|
||||
agent_load_file: C:\[Path]\[agent_saved_filename.zip]
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import Any, Dict, Final, Union, Optional
|
||||
import yaml
|
||||
|
||||
from primaite import USERS_CONFIG_DIR, getLogger
|
||||
from primaite.common.enums import DeepLearningFramework
|
||||
from primaite.common.enums import ActionType, RedAgentIdentifier, \
|
||||
AgentFramework, SessionType
|
||||
|
||||
@@ -20,10 +21,13 @@ _EXAMPLE_TRAINING: Final[Path] = USERS_CONFIG_DIR / "example_config" / "training
|
||||
class TrainingConfig:
|
||||
"""The Training Config class."""
|
||||
agent_framework: AgentFramework = AgentFramework.SB3
|
||||
"The agent framework."
|
||||
"The AgentFramework"
|
||||
|
||||
deep_learning_framework: DeepLearningFramework = DeepLearningFramework.TF
|
||||
"The DeepLearningFramework."
|
||||
|
||||
red_agent_identifier: RedAgentIdentifier = RedAgentIdentifier.PPO
|
||||
"The red agent/algo class."
|
||||
"The RedAgentIdentifier.."
|
||||
|
||||
action_type: ActionType = ActionType.ANY
|
||||
"The ActionType to use."
|
||||
@@ -33,6 +37,10 @@ class TrainingConfig:
|
||||
|
||||
num_steps: int = 256
|
||||
"The number of steps in an episode."
|
||||
|
||||
checkpoint_every_n_episodes: int = 5
|
||||
"The agent will save a checkpoint every n episodes."
|
||||
|
||||
observation_space: dict = field(
|
||||
default_factory=lambda: {"components": [{"name": "NODE_LINK_TABLE"}]}
|
||||
)
|
||||
@@ -148,6 +156,7 @@ class TrainingConfig:
|
||||
) -> TrainingConfig:
|
||||
field_enum_map = {
|
||||
"agent_framework": AgentFramework,
|
||||
"deep_learning_framework": DeepLearningFramework,
|
||||
"red_agent_identifier": RedAgentIdentifier,
|
||||
"action_type": ActionType,
|
||||
"session_type": SessionType
|
||||
@@ -155,7 +164,7 @@ class TrainingConfig:
|
||||
|
||||
for field, enum_class in field_enum_map.items():
|
||||
if field in config_dict:
|
||||
config_dict[field] = enum_class[field]
|
||||
config_dict[field] = enum_class[config_dict[field]]
|
||||
|
||||
return TrainingConfig(**config_dict)
|
||||
|
||||
@@ -219,7 +228,7 @@ def load(file_path: Union[str, Path],
|
||||
)
|
||||
_LOGGER.error(msg)
|
||||
try:
|
||||
return TrainingConfig.from_dict(**config)
|
||||
return TrainingConfig.from_dict(config)
|
||||
except TypeError as e:
|
||||
msg = (
|
||||
f"Error when creating an instance of {TrainingConfig} "
|
||||
|
||||
@@ -5,7 +5,7 @@ import csv
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Dict, Tuple, Union
|
||||
from typing import Dict, Tuple, Union, Final
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
@@ -77,6 +77,8 @@ class Primaite(Env):
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
self.session_path: Final[Path] = session_path
|
||||
self.timestamp_str: Final[str] = timestamp_str
|
||||
self._training_config_path = training_config_path
|
||||
self._lay_down_config_path = lay_down_config_path
|
||||
|
||||
@@ -93,7 +95,7 @@ class Primaite(Env):
|
||||
self.transaction_list = transaction_list
|
||||
|
||||
# The agent in use
|
||||
self.agent_identifier = self.training_config.agent_identifier
|
||||
self.agent_identifier = self.training_config.red_agent_identifier
|
||||
|
||||
# Create a dictionary to hold all the nodes
|
||||
self.nodes: Dict[str, NodeUnion] = {}
|
||||
|
||||
@@ -108,54 +108,6 @@ def run_stable_baselines3_ppo(
|
||||
env.close()
|
||||
|
||||
|
||||
def run_stable_baselines3_a2c(
|
||||
env: Primaite, config_values: TrainingConfig, session_path: Path, timestamp_str: str
|
||||
):
|
||||
"""
|
||||
Run against a stable_baselines3 A2C agent.
|
||||
|
||||
:param env: An instance of
|
||||
:class:`~primaite.environment.primaite_env.Primaite`.
|
||||
:param config_values: An instance of
|
||||
:class:`~primaite.config.training_config.TrainingConfig`.
|
||||
param session_path: The directory path the session is writing to.
|
||||
:param timestamp_str: The session timestamp in the format:
|
||||
<yyyy-mm-dd>_<hh-mm-ss>.
|
||||
"""
|
||||
if config_values.load_agent:
|
||||
try:
|
||||
agent = A2C.load(
|
||||
config_values.agent_load_file,
|
||||
env,
|
||||
verbose=0,
|
||||
n_steps=config_values.num_steps,
|
||||
)
|
||||
except Exception:
|
||||
print(
|
||||
"ERROR: Could not load agent at location: "
|
||||
+ config_values.agent_load_file
|
||||
)
|
||||
_LOGGER.error("Could not load agent")
|
||||
_LOGGER.error("Exception occured", exc_info=True)
|
||||
else:
|
||||
agent = A2C("MlpPolicy", env, verbose=0, n_steps=config_values.num_steps)
|
||||
|
||||
if config_values.session_type == "TRAINING":
|
||||
# We're in a training session
|
||||
print("Starting training session...")
|
||||
_LOGGER.debug("Starting training session...")
|
||||
for episode in range(config_values.num_episodes):
|
||||
agent.learn(total_timesteps=config_values.num_steps)
|
||||
_save_agent(agent, session_path, timestamp_str)
|
||||
else:
|
||||
# Default to being in an evaluation session
|
||||
print("Starting evaluation session...")
|
||||
_LOGGER.debug("Starting evaluation session...")
|
||||
evaluate_policy(agent, env, n_eval_episodes=config_values.num_episodes)
|
||||
|
||||
env.close()
|
||||
|
||||
|
||||
def _write_session_metadata_file(
|
||||
session_dir: Path, uuid: str, session_timestamp: datetime, env: Primaite
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user